podman-remote.conf enablement

add the ability for the podman remote client to use a configuration file
which describes its connections. users can now define a connection the
configuration and then call it by name like:

podman-remote -c connection1

and the destination and user will be derived from the configuration
file.  if no -c is provided, we look for a connection in the
configuration file designated as 'default'.  If the configuration file
has only one connection, it will be deemed the 'default'.

Signed-off-by: baude <bbaude@redhat.com>
This commit is contained in:
baude
2019-05-28 09:21:22 -05:00
parent 8a8db34131
commit dc7ae31171
24 changed files with 1017 additions and 332 deletions

View File

@ -33,9 +33,11 @@ type MainFlags struct {
LogLevel string
TmpDir string
RemoteUserName string
RemoteHost string
VarlinkAddress string
RemoteUserName string
RemoteHost string
VarlinkAddress string
ConnectionName string
RemoteConfigFilePath string
}
type AttachValues struct {

View File

@ -9,6 +9,8 @@ import (
const remote = true
func init() {
rootCmd.PersistentFlags().StringVarP(&MainGlobalOpts.ConnectionName, "connection", "c", "", "remote connection name")
rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteConfigFilePath, "remote-config-path", "", "alternate path for configuration file")
rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteUserName, "username", "", "username on the remote host")
rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteHost, "remote-host", "", "remote host")
// TODO maybe we allow the altering of this for bridge connections?

View File

@ -0,0 +1,21 @@
package remoteclientconfig
const remoteConfigFileName string = "podman-remote.conf"
// RemoteConfig describes the podman remote configuration file
type RemoteConfig struct {
Connections map[string]RemoteConnection
}
// RemoteConnection describes the attributes of a podman-remote endpoint
type RemoteConnection struct {
Destination string `toml:"destination"`
Username string `toml:"username"`
IsDefault bool `toml:"default"`
}
// GetConfigFilePath is a simple helper to export the configuration file's
// path based on arch, etc
func GetConfigFilePath() string {
return getConfigFilePath()
}

View File

@ -0,0 +1,12 @@
package remoteclientconfig
import (
"path/filepath"
"github.com/docker/docker/pkg/homedir"
)
func getConfigFilePath() string {
homeDir := homedir.Get()
return filepath.Join(homeDir, ".config", "containers", remoteConfigFileName)
}

View File

@ -0,0 +1,12 @@
package remoteclientconfig
import (
"path/filepath"
"github.com/docker/docker/pkg/homedir"
)
func getConfigFilePath() string {
homeDir := homedir.Get()
return filepath.Join(homeDir, ".config", "containers", remoteConfigFileName)
}

View File

@ -0,0 +1,12 @@
package remoteclientconfig
import (
"path/filepath"
"github.com/docker/docker/pkg/homedir"
)
func getConfigFilePath() string {
homeDir := homedir.Get()
return filepath.Join(homeDir, "AppData", "podman", remoteConfigFileName)
}

View File

@ -0,0 +1,62 @@
package remoteclientconfig
import (
"io"
"github.com/BurntSushi/toml"
"github.com/pkg/errors"
)
// ReadRemoteConfig takes an io.Reader representing the remote configuration
// file and returns a remoteconfig
func ReadRemoteConfig(reader io.Reader) (*RemoteConfig, error) {
var remoteConfig RemoteConfig
// the configuration file does not exist
if reader == nil {
return &remoteConfig, ErrNoConfigationFile
}
_, err := toml.DecodeReader(reader, &remoteConfig)
if err != nil {
return nil, err
}
// We need to validate each remote connection has fields filled out
for name, conn := range remoteConfig.Connections {
if len(conn.Destination) < 1 {
return nil, errors.Errorf("connection %s has no destination defined", name)
}
}
return &remoteConfig, err
}
// GetDefault returns the default RemoteConnection. If there is only one
// connection, we assume it is the default as well
func (r *RemoteConfig) GetDefault() (*RemoteConnection, error) {
if len(r.Connections) == 0 {
return nil, ErrNoDefinedConnections
}
for _, v := range r.Connections {
if len(r.Connections) == 1 {
// if there is only one defined connection, we assume it is
// the default whether tagged as such or not
return &v, nil
}
if v.IsDefault {
return &v, nil
}
}
return nil, ErrNoDefaultConnection
}
// GetRemoteConnection "looks up" a remote connection by name and returns it in the
// form of a RemoteConnection
func (r *RemoteConfig) GetRemoteConnection(name string) (*RemoteConnection, error) {
if len(r.Connections) == 0 {
return nil, ErrNoDefinedConnections
}
for k, v := range r.Connections {
if k == name {
return &v, nil
}
}
return nil, errors.Wrap(ErrConnectionNotFound, name)
}

View File

@ -0,0 +1,201 @@
package remoteclientconfig
import (
"io"
"reflect"
"strings"
"testing"
)
var goodConfig = `
[connections]
[connections.homer]
destination = "192.168.1.1"
username = "myuser"
default = true
[connections.bart]
destination = "foobar.com"
username = "root"
`
var noDest = `
[connections]
[connections.homer]
destination = "192.168.1.1"
username = "myuser"
default = true
[connections.bart]
username = "root"
`
var noUser = `
[connections]
[connections.homer]
destination = "192.168.1.1"
`
func makeGoodResult() *RemoteConfig {
var goodConnections = make(map[string]RemoteConnection)
goodConnections["homer"] = RemoteConnection{
Destination: "192.168.1.1",
Username: "myuser",
IsDefault: true,
}
goodConnections["bart"] = RemoteConnection{
Destination: "foobar.com",
Username: "root",
}
var goodResult = RemoteConfig{
Connections: goodConnections,
}
return &goodResult
}
func makeNoUserResult() *RemoteConfig {
var goodConnections = make(map[string]RemoteConnection)
goodConnections["homer"] = RemoteConnection{
Destination: "192.168.1.1",
}
var goodResult = RemoteConfig{
Connections: goodConnections,
}
return &goodResult
}
func TestReadRemoteConfig(t *testing.T) {
type args struct {
reader io.Reader
}
tests := []struct {
name string
args args
want *RemoteConfig
wantErr bool
}{
// good test should pass
{"good", args{reader: strings.NewReader(goodConfig)}, makeGoodResult(), false},
// a connection with no destination is an error
{"nodest", args{reader: strings.NewReader(noDest)}, nil, true},
// a connnection with no user is OK
{"nouser", args{reader: strings.NewReader(noUser)}, makeNoUserResult(), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ReadRemoteConfig(tt.args.reader)
if (err != nil) != tt.wantErr {
t.Errorf("ReadRemoteConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ReadRemoteConfig() = %v, want %v", got, tt.want)
}
})
}
}
func TestRemoteConfig_GetDefault(t *testing.T) {
good := make(map[string]RemoteConnection)
good["homer"] = RemoteConnection{
Username: "myuser",
Destination: "192.168.1.1",
IsDefault: true,
}
good["bart"] = RemoteConnection{
Username: "root",
Destination: "foobar.com",
}
noDefault := make(map[string]RemoteConnection)
noDefault["homer"] = RemoteConnection{
Username: "myuser",
Destination: "192.168.1.1",
}
noDefault["bart"] = RemoteConnection{
Username: "root",
Destination: "foobar.com",
}
single := make(map[string]RemoteConnection)
single["homer"] = RemoteConnection{
Username: "myuser",
Destination: "192.168.1.1",
}
none := make(map[string]RemoteConnection)
type fields struct {
Connections map[string]RemoteConnection
}
tests := []struct {
name string
fields fields
want *RemoteConnection
wantErr bool
}{
// A good toml should return the connection that is marked isDefault
{"good", fields{Connections: makeGoodResult().Connections}, &RemoteConnection{"192.168.1.1", "myuser", true}, false},
// If nothing is marked as isDefault and there is more than one connection, error should occur
{"nodefault", fields{Connections: noDefault}, nil, true},
// if nothing is marked as isDefault but there is only one connection, the one connection is considered the default
{"single", fields{Connections: none}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &RemoteConfig{
Connections: tt.fields.Connections,
}
got, err := r.GetDefault()
if (err != nil) != tt.wantErr {
t.Errorf("RemoteConfig.GetDefault() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RemoteConfig.GetDefault() = %v, want %v", got, tt.want)
}
})
}
}
func TestRemoteConfig_GetRemoteConnection(t *testing.T) {
type fields struct {
Connections map[string]RemoteConnection
}
type args struct {
name string
}
blank := make(map[string]RemoteConnection)
tests := []struct {
name string
fields fields
args args
want *RemoteConnection
wantErr bool
}{
// Good connection
{"goodhomer", fields{Connections: makeGoodResult().Connections}, args{name: "homer"}, &RemoteConnection{"192.168.1.1", "myuser", true}, false},
// Good connection
{"goodbart", fields{Connections: makeGoodResult().Connections}, args{name: "bart"}, &RemoteConnection{"foobar.com", "root", false}, false},
// Getting an unknown connection should result in error
{"noexist", fields{Connections: makeGoodResult().Connections}, args{name: "foobar"}, nil, true},
// Getting a connection when there are none should result in an error
{"none", fields{Connections: blank}, args{name: "foobar"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &RemoteConfig{
Connections: tt.fields.Connections,
}
got, err := r.GetRemoteConnection(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("RemoteConfig.GetRemoteConnection() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RemoteConfig.GetRemoteConnection() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,14 @@
package remoteclientconfig
import "errors"
var (
// ErrNoDefaultConnection no default connection is defined in the podman-remote.conf file
ErrNoDefaultConnection = errors.New("no default connection is defined")
// ErrNoDefinedConnections no connections are defined in the podman-remote.conf file
ErrNoDefinedConnections = errors.New("no remote connections have been defined")
// ErrConnectionNotFound unable to lookup connection by name
ErrConnectionNotFound = errors.New("remote connection not found by name")
// ErrNoConfigationFile no config file found
ErrNoConfigationFile = errors.New("no configuration file found")
)

View File

@ -0,0 +1,47 @@
% podman-remote.conf(5)
## NAME
podman-remote.conf - configuration file for the podman remote client
## DESCRIPTION
The libpod.conf file is the default configuration file for all tools using
libpod to manage containers.
The podman-remote.conf file is the default configuration file for the podman
remote client. It is in the TOML format. It is primarily used to keep track
of the user's remote connections.
## CONNECTION OPTIONS
**destination** = ""
The hostname or IP address of the remote system
**username** = ""
The username to use when connecting to the remote system
**default** = bool
Denotes whether the connection is the default connection for the user. The default connection
is used when the user does not specify a destination or connection name to `podman`.
## EXAMPLE
The following example depicts a configuration file with two connections. One of the connections
is designated as the default connection.
```
[connections]
[connections.host1]
destination = "host1"
username = "homer"
default = true
[connections.host2]
destination = "192.168.122.133"
username = "fedora"
```
## FILES
`/$HOME/.config/containers/podman-remote.conf`, default location for the podman remote
configuration file
## HISTORY
May 2019, Originally compiled by Brent Baude<bbaude@redhat.com>

View File

@ -6,42 +6,52 @@ import (
"fmt"
"os"
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/varlink/go/varlink"
)
var remoteEndpoint *Endpoint
func (r RemoteRuntime) RemoteEndpoint() (remoteEndpoint *Endpoint, err error) {
if remoteEndpoint == nil {
remoteEndpoint = &Endpoint{Unknown, ""}
} else {
return remoteEndpoint, nil
}
remoteConfigConnections, _ := remoteclientconfig.ReadRemoteConfig(r.config)
// I'm leaving this here for now as a document of the birdge format. It can be removed later once the bridge
// function is more flushed out.
// bridge := `ssh -T root@192.168.122.1 "/usr/bin/varlink -A '/usr/bin/podman varlink \$VARLINK_ADDRESS' bridge"`
if len(r.cmd.RemoteHost) > 0 {
// The user has provided a remote host endpoint
// If the user defines an env variable for podman_varlink_bridge
// we use that as passed.
if bridge := os.Getenv("PODMAN_VARLINK_BRIDGE"); bridge != "" {
logrus.Debug("creating a varlink bridge based on env variable")
remoteEndpoint, err = newBridgeConnection(bridge, nil, r.cmd.LogLevel)
// if an environment variable for podman_varlink_address is defined,
// we used that as passed
} else if address := os.Getenv("PODMAN_VARLINK_ADDRESS"); address != "" {
logrus.Debug("creating a varlink address based on env variable: %s", address)
remoteEndpoint, err = newSocketConnection(address)
// if the user provides a remote host, we use it to configure a bridge connection
} else if len(r.cmd.RemoteHost) > 0 {
logrus.Debug("creating a varlink bridge based on user input")
if len(r.cmd.RemoteUserName) < 1 {
return nil, errors.New("you must provide a username when providing a remote host name")
}
remoteEndpoint.Type = BridgeConnection
remoteEndpoint.Connection = fmt.Sprintf(
`ssh -T %s@%s /usr/bin/varlink -A \'/usr/bin/podman --log-level=%s varlink \\\$VARLINK_ADDRESS\' bridge`,
r.cmd.RemoteUserName, r.cmd.RemoteHost, r.cmd.LogLevel)
} else if bridge := os.Getenv("PODMAN_VARLINK_BRIDGE"); bridge != "" {
remoteEndpoint.Type = BridgeConnection
remoteEndpoint.Connection = bridge
} else {
address := os.Getenv("PODMAN_VARLINK_ADDRESS")
if address == "" {
address = DefaultAddress
rc := remoteclientconfig.RemoteConnection{r.cmd.RemoteHost, r.cmd.RemoteUserName, false}
remoteEndpoint, err = newBridgeConnection("", &rc, r.cmd.LogLevel)
// if the user has a config file with connections in it
} else if len(remoteConfigConnections.Connections) > 0 {
logrus.Debug("creating a varlink bridge based configuration file")
var rc *remoteclientconfig.RemoteConnection
if len(r.cmd.ConnectionName) > 0 {
rc, err = remoteConfigConnections.GetRemoteConnection(r.cmd.ConnectionName)
} else {
rc, err = remoteConfigConnections.GetDefault()
}
remoteEndpoint.Type = DirectConnection
remoteEndpoint.Connection = address
if err != nil {
return nil, err
}
remoteEndpoint, err = newBridgeConnection("", rc, r.cmd.LogLevel)
// last resort is to make a socket connection with the default varlink address for root user
} else {
logrus.Debug("creating a varlink address based default root address")
remoteEndpoint, err = newSocketConnection(DefaultAddress)
}
return
}
@ -72,3 +82,12 @@ func (r RemoteRuntime) RefreshConnection() error {
r.Conn = newConn
return nil
}
// newSocketConnection returns an endpoint for a uds based connection
func newSocketConnection(address string) (*Endpoint, error) {
endpoint := Endpoint{
Type: DirectConnection,
Connection: address,
}
return &endpoint, nil
}

View File

@ -0,0 +1,30 @@
// +build linux darwin
// +build remoteclient
package adapter
import (
"fmt"
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/pkg/errors"
)
// newBridgeConnection creates a bridge type endpoint with username, destination, and log-level
func newBridgeConnection(formattedBridge string, remoteConn *remoteclientconfig.RemoteConnection, logLevel string) (*Endpoint, error) {
endpoint := Endpoint{
Type: BridgeConnection,
}
if len(formattedBridge) < 1 && remoteConn == nil {
return nil, errors.New("bridge connections must either be created by string or remoteconnection")
}
if len(formattedBridge) > 0 {
endpoint.Connection = formattedBridge
return &endpoint, nil
}
endpoint.Connection = fmt.Sprintf(
`ssh -T %s@%s -- /usr/bin/varlink -A \'/usr/bin/podman --log-level=%s varlink \\\$VARLINK_ADDRESS\' bridge`,
remoteConn.Username, remoteConn.Destination, logLevel)
return &endpoint, nil
}

View File

@ -0,0 +1,15 @@
// +build remoteclient
package adapter
import (
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/containers/libpod/libpod"
)
func newBridgeConnection(formattedBridge string, remoteConn *remoteclientconfig.RemoteConnection, logLevel string) (*Endpoint, error) {
// TODO
// Unix and Windows appear to quote their ssh implementations differently therefore once we figure out what
// windows ssh is doing here, we can then get the format correct.
return nil, libpod.ErrNotImplemented
}

View File

@ -20,6 +20,7 @@ import (
"github.com/containers/image/docker/reference"
"github.com/containers/image/types"
"github.com/containers/libpod/cmd/podman/cliconfig"
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/containers/libpod/cmd/podman/varlink"
"github.com/containers/libpod/libpod"
"github.com/containers/libpod/libpod/events"
@ -40,6 +41,7 @@ type RemoteRuntime struct {
Conn *varlink.Connection
Remote bool
cmd cliconfig.MainFlags
config io.Reader
}
// LocalRuntime describes a typical libpod runtime
@ -49,10 +51,35 @@ type LocalRuntime struct {
// GetRuntime returns a LocalRuntime struct with the actual runtime embedded in it
func GetRuntime(ctx context.Context, c *cliconfig.PodmanCommand) (*LocalRuntime, error) {
var (
customConfig bool
err error
f *os.File
)
runtime := RemoteRuntime{
Remote: true,
cmd: c.GlobalFlags,
}
configPath := remoteclientconfig.GetConfigFilePath()
if len(c.GlobalFlags.RemoteConfigFilePath) > 0 {
configPath = c.GlobalFlags.RemoteConfigFilePath
customConfig = true
}
f, err = os.Open(configPath)
if err != nil {
// If user does not explicitly provide a configuration file path and we cannot
// find a default, no error should occur.
if os.IsNotExist(err) && !customConfig {
logrus.Debugf("unable to load configuration file at %s", configPath)
runtime.config = nil
} else {
return nil, errors.Wrapf(err, "unable to load configuration file at %s", configPath)
}
} else {
// create the io reader for the remote client
runtime.config = bufio.NewReader(f)
}
conn, err := runtime.Connect()
if err != nil {
return nil, err

View File

@ -3,7 +3,7 @@
#
# TODO: no release, can we find an alternative?
github.com/Azure/go-ansiterm d6e3b3328b783f23731bc4d058875b0371ff8109
github.com/BurntSushi/toml v0.2.0
github.com/BurntSushi/toml v0.3.1
github.com/Microsoft/go-winio v0.4.11
github.com/Microsoft/hcsshim v0.8.3
github.com/blang/semver v3.5.0

View File

@ -1,14 +1,21 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
Version 2, December 2004
The MIT License (MIT)
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net>
Copyright (c) 2013 TOML authors
Everyone is permitted to copy and distribute verbatim or modified
copies of this license document, and changing it is allowed as long
as the name is changed.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. You just DO WHAT THE FUCK YOU WANT TO.
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -6,12 +6,12 @@ packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
Spec: https://github.com/mojombo/toml
Spec: https://github.com/toml-lang/toml
Compatible with TOML version
[v0.2.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.2.0.md)
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Documentation: http://godoc.org/github.com/BurntSushi/toml
Documentation: https://godoc.org/github.com/BurntSushi/toml
Installation:
@ -26,8 +26,7 @@ go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build status](https://api.travis-ci.org/BurntSushi/toml.png)](https://travis-ci.org/BurntSushi/toml)
[![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml)
### Testing
@ -217,4 +216,3 @@ Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.

View File

@ -10,7 +10,9 @@ import (
"time"
)
var e = fmt.Errorf
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
@ -103,6 +105,13 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
}
p, err := parse(data)
if err != nil {
return MetaData{}, err
@ -111,7 +120,7 @@ func Decode(data string, v interface{}) (MetaData, error) {
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, rvalue(v))
return md, md.unify(p.mapping, indirect(rv))
}
// DecodeFile is just like Decode, except it will automatically read the
@ -211,7 +220,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Interface:
// we only support empty interfaces.
if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind())
return e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
@ -219,7 +228,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("Unsupported type '%s'.", rv.Kind())
return e("unsupported type %s", rv.Kind())
}
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
@ -228,7 +237,8 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if mapping == nil {
return nil
}
return mismatch(rv, "map", mapping)
return e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
}
for key, datum := range tmap {
@ -253,14 +263,13 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
return e("Type mismatch for '%s.%s': %s",
rv.Type().String(), f.name, err)
return err
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+
"be loaded with reflection.", rv.Type().String(), f.name)
return e("cannot write unexported field %s.%s",
rv.Type().String(), f.name)
}
}
}
@ -378,15 +387,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
// No bounds checking necessary.
case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 {
return e("Value '%d' is out of range for int8.", num)
return e("value %d is out of range for int8", num)
}
case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 {
return e("Value '%d' is out of range for int16.", num)
return e("value %d is out of range for int16", num)
}
case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 {
return e("Value '%d' is out of range for int32.", num)
return e("value %d is out of range for int32", num)
}
}
rv.SetInt(num)
@ -397,15 +406,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
// No bounds checking necessary.
case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 {
return e("Value '%d' is out of range for uint8.", num)
return e("value %d is out of range for uint8", num)
}
case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 {
return e("Value '%d' is out of range for uint16.", num)
return e("value %d is out of range for uint16", num)
}
case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 {
return e("Value '%d' is out of range for uint32.", num)
return e("value %d is out of range for uint32", num)
}
}
rv.SetUint(unum)
@ -471,7 +480,7 @@ func rvalue(v interface{}) reflect.Value {
// interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanAddr() {
if v.CanSet() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv
@ -496,10 +505,5 @@ func isUnifiable(rv reflect.Value) bool {
}
func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
user.Type().String(), expected, data)
return e("cannot load TOML value of type %T into a Go %s", data, expected)
}

View File

@ -77,9 +77,8 @@ func (k Key) maybeQuoted(i int) string {
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
} else {
return k[i]
}
return k[i]
}
func (k Key) add(piece string) Key {

View File

@ -4,7 +4,7 @@ files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
The specification implemented: https://github.com/mojombo/toml
The specification implemented: https://github.com/toml-lang/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the

View File

@ -16,17 +16,17 @@ type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"can't encode array with mixed element types")
"toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New(
"can't encode array with nil element")
"toml: cannot encode array with nil element")
errNonString = errors.New(
"can't encode a map with non-string key type")
"toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New(
"can't encode an anonymous field that is not a struct")
"toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"TOML array element can't contain a table")
"toml: TOML array element cannot contain a table")
errNoKey = errors.New(
"top-level values must be a Go map or struct")
"toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
)
@ -148,7 +148,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("Unsupported type for key '%s': %s", key, k))
panic(e("unsupported type for key '%s': %s", key, k))
}
}
@ -160,7 +160,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z"))
enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
@ -191,7 +191,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
case reflect.String:
enc.writeQuoted(rv.String())
default:
panic(e("Unexpected primitive type: %s", rv.Kind()))
panic(e("unexpected primitive type: %s", rv.Kind()))
}
}
@ -241,7 +241,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra new line between top-level tables.
// Output an extra newline between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
@ -315,10 +315,16 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
t := f.Type
switch t.Kind() {
case reflect.Struct:
addFields(t, frv, f.Index)
continue
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, f.Index)
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct {
if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), f.Index)
}
@ -347,17 +353,18 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
continue
}
tag := sft.Tag.Get("toml")
if tag == "-" {
opts := getOptions(sft.Tag)
if opts.skip {
continue
}
keyName, opts := getOptions(tag)
if keyName == "" {
keyName = sft.Name
keyName := sft.Name
if opts.name != "" {
keyName = opts.name
}
if _, ok := opts["omitempty"]; ok && isEmpty(sf) {
if opts.omitempty && isEmpty(sf) {
continue
} else if _, ok := opts["omitzero"]; ok && isZero(sf) {
}
if opts.omitzero && isZero(sf) {
continue
}
@ -392,9 +399,8 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash
} else {
return tomlArray
}
return tomlArray
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
@ -451,17 +457,30 @@ func tomlArrayType(rv reflect.Value) tomlType {
return firstType
}
func getOptions(keyName string) (string, map[string]struct{}) {
opts := make(map[string]struct{})
ss := strings.Split(keyName, ",")
name := ss[0]
if len(ss) > 1 {
for _, opt := range ss {
opts[opt] = struct{}{}
type tagOptions struct {
skip bool // "-"
name string
omitempty bool
omitzero bool
}
func getOptions(tag reflect.StructTag) tagOptions {
t := tag.Get("toml")
if t == "-" {
return tagOptions{skip: true}
}
var opts tagOptions
parts := strings.Split(t, ",")
opts.name = parts[0]
for _, s := range parts[1:] {
switch s {
case "omitempty":
opts.omitempty = true
case "omitzero":
opts.omitzero = true
}
}
return name, opts
return opts
}
func isZero(rv reflect.Value) bool {

View File

@ -3,6 +3,7 @@ package toml
import (
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
@ -29,24 +30,28 @@ const (
itemArrayTableEnd
itemKeyStart
itemCommentStart
itemInlineTableStart
itemInlineTableEnd
)
const (
eof = 0
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
arrayValTerm = ','
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
eof = 0
comma = ','
tableStart = '['
tableEnd = ']'
arrayTableStart = '['
arrayTableEnd = ']'
tableSep = '.'
keySep = '='
arrayStart = '['
arrayEnd = ']'
commentStart = '#'
stringStart = '"'
stringEnd = '"'
rawStringStart = '\''
rawStringEnd = '\''
inlineTableStart = '{'
inlineTableEnd = '}'
)
type stateFn func(lx *lexer) stateFn
@ -55,11 +60,18 @@ type lexer struct {
input string
start int
pos int
width int
line int
state stateFn
items chan item
// Allow for backing up up to three runes.
// This is necessary because TOML contains 3-rune tokens (""" and ''').
prevWidths [3]int
nprev int // how many of prevWidths are in use
// If we emit an eof, we can still back up, but it is not OK to call
// next again.
atEOF bool
// A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily
@ -87,7 +99,7 @@ func (lx *lexer) nextItem() item {
func lex(input string) *lexer {
lx := &lexer{
input: input + "\n",
input: input,
state: lexTop,
line: 1,
items: make(chan item, 10),
@ -102,7 +114,7 @@ func (lx *lexer) push(state stateFn) {
func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.")
return lx.errorf("BUG in lexer: no states to pop")
}
last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1]
@ -124,16 +136,25 @@ func (lx *lexer) emitTrim(typ itemType) {
}
func (lx *lexer) next() (r rune) {
if lx.atEOF {
panic("next called after EOF")
}
if lx.pos >= len(lx.input) {
lx.width = 0
lx.atEOF = true
return eof
}
if lx.input[lx.pos] == '\n' {
lx.line++
}
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.pos += lx.width
lx.prevWidths[2] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[0]
if lx.nprev < 3 {
lx.nprev++
}
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.prevWidths[0] = w
lx.pos += w
return r
}
@ -142,9 +163,20 @@ func (lx *lexer) ignore() {
lx.start = lx.pos
}
// backup steps back one rune. Can be called only once per call of next.
// backup steps back one rune. Can be called only twice between calls to next.
func (lx *lexer) backup() {
lx.pos -= lx.width
if lx.atEOF {
lx.atEOF = false
return
}
if lx.nprev < 1 {
panic("backed up too far")
}
w := lx.prevWidths[0]
lx.prevWidths[0] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[2]
lx.nprev--
lx.pos -= w
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line--
}
@ -166,9 +198,22 @@ func (lx *lexer) peek() rune {
return r
}
// skip ignores all input that matches the given predicate.
func (lx *lexer) skip(pred func(rune) bool) {
for {
r := lx.next()
if pred(r) {
continue
}
lx.backup()
lx.ignore()
return
}
}
// errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special
// character (new lines, tabs, etc.).
// character (newlines, tabs, etc.).
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
lx.items <- item{
itemError,
@ -184,7 +229,6 @@ func lexTop(lx *lexer) stateFn {
if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop)
}
switch r {
case commentStart:
lx.push(lexTop)
@ -193,7 +237,7 @@ func lexTop(lx *lexer) stateFn {
return lexTableStart
case eof:
if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.")
return lx.errorf("unexpected EOF")
}
lx.emit(itemEOF)
return nil
@ -208,12 +252,12 @@ func lexTop(lx *lexer) stateFn {
// lexTopEnd is entered whenever a top-level item has been consumed. (A value
// or a table.) It must see only whitespace, and will turn back to lexTop
// upon a new line. If it sees EOF, it will quit the lexer successfully.
// upon a newline. If it sees EOF, it will quit the lexer successfully.
func lexTopEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case r == commentStart:
// a comment will read to a new line for us.
// a comment will read to a newline for us.
lx.push(lexTop)
return lexCommentStart
case isWhitespace(r):
@ -222,11 +266,11 @@ func lexTopEnd(lx *lexer) stateFn {
lx.ignore()
return lexTop
case r == eof:
lx.ignore()
return lexTop
lx.emit(itemEOF)
return nil
}
return lx.errorf("Expected a top-level item to end with a new line, "+
"comment or EOF, but got %q instead.", r)
return lx.errorf("expected a top-level item to end with a newline, "+
"comment, or EOF, but got %q instead", r)
}
// lexTable lexes the beginning of a table. Namely, it makes sure that
@ -253,21 +297,22 @@ func lexTableEnd(lx *lexer) stateFn {
func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter %q, "+
"but got %q instead.", arrayTableEnd, r)
return lx.errorf("expected end of table array name delimiter %q, "+
"but got %q instead", arrayTableEnd, r)
}
lx.emit(itemArrayTableEnd)
return lexTopEnd
}
func lexTableNameStart(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.peek(); {
case r == tableEnd || r == eof:
return lx.errorf("Unexpected end of table name. (Table names cannot " +
"be empty.)")
return lx.errorf("unexpected end of table name " +
"(table names cannot be empty)")
case r == tableSep:
return lx.errorf("Unexpected table separator. (Table names cannot " +
"be empty.)")
return lx.errorf("unexpected table separator " +
"(table names cannot be empty)")
case r == stringStart || r == rawStringStart:
lx.ignore()
lx.push(lexTableNameEnd)
@ -277,24 +322,22 @@ func lexTableNameStart(lx *lexer) stateFn {
}
}
// lexTableName lexes the name of a table. It assumes that at least one
// lexBareTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn {
switch r := lx.next(); {
case isBareKeyChar(r):
r := lx.next()
if isBareKeyChar(r) {
return lexBareTableName
case r == tableSep || r == tableEnd:
lx.backup()
lx.emitTrim(itemText)
return lexTableNameEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
}
lx.backup()
lx.emit(itemText)
return lexTableNameEnd
}
// lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.next(); {
case isWhitespace(r):
return lexTableNameEnd
@ -304,8 +347,8 @@ func lexTableNameEnd(lx *lexer) stateFn {
case r == tableEnd:
return lx.pop()
default:
return lx.errorf("Expected '.' or ']' to end table name, but got %q "+
"instead.", r)
return lx.errorf("expected '.' or ']' to end table name, "+
"but got %q instead", r)
}
}
@ -315,7 +358,7 @@ func lexKeyStart(lx *lexer) stateFn {
r := lx.peek()
switch {
case r == keySep:
return lx.errorf("Unexpected key separator %q.", keySep)
return lx.errorf("unexpected key separator %q", keySep)
case isWhitespace(r) || isNL(r):
lx.next()
return lexSkip(lx, lexKeyStart)
@ -338,14 +381,15 @@ func lexBareKey(lx *lexer) stateFn {
case isBareKeyChar(r):
return lexBareKey
case isWhitespace(r):
lx.emitTrim(itemText)
lx.backup()
lx.emit(itemText)
return lexKeyEnd
case r == keySep:
lx.backup()
lx.emitTrim(itemText)
lx.emit(itemText)
return lexKeyEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
return lx.errorf("bare keys cannot contain %q", r)
}
}
@ -358,7 +402,7 @@ func lexKeyEnd(lx *lexer) stateFn {
case isWhitespace(r):
return lexSkip(lx, lexKeyEnd)
default:
return lx.errorf("Expected key separator %q, but got %q instead.",
return lx.errorf("expected key separator %q, but got %q instead",
keySep, r)
}
}
@ -367,20 +411,26 @@ func lexKeyEnd(lx *lexer) stateFn {
// lexValue will ignore whitespace.
// After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT new lines.
// In array syntax, the array states are responsible for ignoring new
// lines.
// We allow whitespace to precede a value, but NOT newlines.
// In array syntax, the array states are responsible for ignoring newlines.
r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch {
case r == arrayStart:
case isWhitespace(r):
return lexSkip(lx, lexValue)
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
}
switch r {
case arrayStart:
lx.ignore()
lx.emit(itemArray)
return lexArrayValue
case r == stringStart:
case inlineTableStart:
lx.ignore()
lx.emit(itemInlineTableStart)
return lexInlineTableValue
case stringStart:
if lx.accept(stringStart) {
if lx.accept(stringStart) {
lx.ignore() // Ignore """
@ -390,7 +440,7 @@ func lexValue(lx *lexer) stateFn {
}
lx.ignore() // ignore the '"'
return lexString
case r == rawStringStart:
case rawStringStart:
if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) {
lx.ignore() // Ignore """
@ -400,23 +450,24 @@ func lexValue(lx *lexer) stateFn {
}
lx.ignore() // ignore the "'"
return lexRawString
case r == 't':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
case '+', '-':
return lexNumberStart
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
case '.': // special error case, be kind to users
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("Expected value but found %q instead.", r)
if unicode.IsLetter(r) {
// Be permissive here; lexBool will give a nice error if the
// user wrote something like
// x = foo
// (i.e. not 'true' or 'false' but is something else word-like.)
lx.backup()
return lexBool
}
return lx.errorf("expected value but found %q instead", r)
}
// lexArrayValue consumes one value in an array. It assumes that '[' or ','
// have already been consumed. All whitespace and new lines are ignored.
// have already been consumed. All whitespace and newlines are ignored.
func lexArrayValue(lx *lexer) stateFn {
r := lx.next()
switch {
@ -425,10 +476,11 @@ func lexArrayValue(lx *lexer) stateFn {
case r == commentStart:
lx.push(lexArrayValue)
return lexCommentStart
case r == arrayValTerm:
return lx.errorf("Unexpected array value terminator %q.",
arrayValTerm)
case r == comma:
return lx.errorf("unexpected comma")
case r == arrayEnd:
// NOTE(caleb): The spec isn't clear about whether you can have
// a trailing comma or not, so we'll allow it.
return lexArrayEnd
}
@ -437,8 +489,9 @@ func lexArrayValue(lx *lexer) stateFn {
return lexValue
}
// lexArrayValueEnd consumes the cruft between values of an array. Namely,
// it ignores whitespace and expects either a ',' or a ']'.
// lexArrayValueEnd consumes everything between the end of an array value and
// the next value (or the end of the array): it ignores whitespace and newlines
// and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
@ -447,31 +500,88 @@ func lexArrayValueEnd(lx *lexer) stateFn {
case r == commentStart:
lx.push(lexArrayValueEnd)
return lexCommentStart
case r == arrayValTerm:
case r == comma:
lx.ignore()
return lexArrayValue // move on to the next value
case r == arrayEnd:
return lexArrayEnd
}
return lx.errorf("Expected an array value terminator %q or an array "+
"terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r)
return lx.errorf(
"expected a comma or array terminator %q, but got %q instead",
arrayEnd, r,
)
}
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has
// just been consumed.
// lexArrayEnd finishes the lexing of an array.
// It assumes that a ']' has just been consumed.
func lexArrayEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemArrayEnd)
return lx.pop()
}
// lexInlineTableValue consumes one key/value pair in an inline table.
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
func lexInlineTableValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == inlineTableEnd:
return lexInlineTableEnd
}
lx.backup()
lx.push(lexInlineTableValueEnd)
return lexKeyStart
}
// lexInlineTableValueEnd consumes everything between the end of an inline table
// key/value pair and the next pair (or the end of the table):
// it ignores whitespace and expects either a ',' or a '}'.
func lexInlineTableValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexInlineTableValue
case r == inlineTableEnd:
return lexInlineTableEnd
}
return lx.errorf("expected a comma or an inline table terminator %q, "+
"but got %q instead", inlineTableEnd, r)
}
// lexInlineTableEnd finishes the lexing of an inline table.
// It assumes that a '}' has just been consumed.
func lexInlineTableEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemInlineTableEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
return lx.errorf("strings cannot contain newlines")
case r == '\\':
lx.push(lexString)
return lexStringEscape
@ -488,11 +598,12 @@ func lexString(lx *lexer) stateFn {
// lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '\\':
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case '\\':
return lexMultilineStringEscape
case r == stringEnd:
case stringEnd:
if lx.accept(stringEnd) {
if lx.accept(stringEnd) {
lx.backup()
@ -516,8 +627,10 @@ func lexMultilineString(lx *lexer) stateFn {
func lexRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r):
return lx.errorf("Strings cannot contain new lines.")
return lx.errorf("strings cannot contain newlines")
case r == rawStringEnd:
lx.backup()
lx.emit(itemRawString)
@ -529,12 +642,13 @@ func lexRawString(lx *lexer) stateFn {
}
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'" has already been consumed and
// a string. It assumes that the beginning "'''" has already been consumed and
// ignored.
func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next()
switch {
case r == rawStringEnd:
switch lx.next() {
case eof:
return lx.errorf("unexpected EOF")
case rawStringEnd:
if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) {
lx.backup()
@ -559,11 +673,10 @@ func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first:
if isNL(lx.next()) {
return lexMultilineString
} else {
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
lx.backup()
lx.push(lexMultilineString)
return lexStringEscape(lx)
}
func lexStringEscape(lx *lexer) stateFn {
@ -588,10 +701,9 @@ func lexStringEscape(lx *lexer) stateFn {
case 'U':
return lexLongUnicodeEscape
}
return lx.errorf("Invalid escape character %q. Only the following "+
return lx.errorf("invalid escape character %q; only the following "+
"escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+
"\\uXXXX and \\UXXXXXXXX.", r)
`\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
}
func lexShortUnicodeEscape(lx *lexer) stateFn {
@ -599,8 +711,8 @@ func lexShortUnicodeEscape(lx *lexer) stateFn {
for i := 0; i < 4; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\u', "+
"but got '%s' instead.", lx.current())
return lx.errorf(`expected four hexadecimal digits after '\u', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
@ -611,40 +723,43 @@ func lexLongUnicodeEscape(lx *lexer) stateFn {
for i := 0; i < 8; i++ {
r = lx.next()
if !isHexadecimal(r) {
return lx.errorf("Expected eight hexadecimal digits after '\\U', "+
"but got '%s' instead.", lx.current())
return lx.errorf(`expected eight hexadecimal digits after '\U', `+
"but got %q instead", lx.current())
}
}
return lx.pop()
}
// lexNumberOrDateStart consumes either a (positive) integer, float or
// datetime. It assumes that NO negative sign has been consumed.
// lexNumberOrDateStart consumes either an integer, a float, or datetime.
func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
if isDigit(r) {
return lexNumberOrDate
}
return lexNumberOrDate
switch r {
case '_':
return lexNumber
case 'e', 'E':
return lexFloat
case '.':
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime.
// lexNumberOrDate consumes either an integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next()
switch {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
if isDigit(r) {
return lexNumberOrDate
case r == '.':
return lexFloatStart
}
switch r {
case '-':
return lexDatetime
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
@ -652,46 +767,34 @@ func lexNumberOrDate(lx *lexer) stateFn {
return lx.pop()
}
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format.
// It assumes that "YYYY-" has already been consumed.
func lexDateAfterYear(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
// lexDatetime consumes a Datetime, to a first approximation.
// The parser validates that it matches one of the accepted formats.
func lexDatetime(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexDatetime
}
for _, f := range formats {
r := lx.next()
if f == '0' {
if !isDigit(r) {
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found %q instead.", r)
}
} else if f != r {
return lx.errorf("Expected %q in ISO8601 datetime, "+
"but found %q instead.", f, r)
}
switch r {
case '-', 'T', ':', '.', 'Z', '+':
return lexDatetime
}
lx.backup()
lx.emit(itemDatetime)
return lx.pop()
}
// lexNumberStart consumes either an integer or a float. It assumes that
// a negative sign has already been read, but that *no* digits have been
// consumed. lexNumberStart will move to the appropriate integer or float
// states.
// lexNumberStart consumes either an integer or a float. It assumes that a sign
// has already been read, but that *no* digits have been consumed.
// lexNumberStart will move to the appropriate integer or float states.
func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit.
// We MUST see a digit. Even floats have to start with a digit.
r := lx.next()
if !isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
return lexNumber
}
@ -699,11 +802,14 @@ func lexNumberStart(lx *lexer) stateFn {
// lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn {
r := lx.next()
switch {
case isDigit(r):
if isDigit(r) {
return lexNumber
case r == '.':
return lexFloatStart
}
switch r {
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
}
lx.backup()
@ -711,60 +817,42 @@ func lexNumber(lx *lexer) stateFn {
return lx.pop()
}
// lexFloatStart starts the consumption of digits of a float after a '.'.
// Namely, at least one digit is required.
func lexFloatStart(lx *lexer) stateFn {
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"%q instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
// lexFloat consumes the elements of a float. It allows any sequence of
// float-like characters, so floats emitted by the lexer are only a first
// approximation and must be validated by the parser.
func lexFloat(lx *lexer) stateFn {
r := lx.next()
if isDigit(r) {
return lexFloat
}
switch r {
case '_', '.', '-', '+', 'e', 'E':
return lexFloat
}
lx.backup()
lx.emit(itemFloat)
return lx.pop()
}
// lexConst consumes the s[1:] in s. It assumes that s[0] has already been
// consumed.
func lexConst(lx *lexer, s string) stateFn {
for i := range s[1:] {
if r := lx.next(); r != rune(s[i+1]) {
return lx.errorf("Expected %q, but found %q instead.", s[:i+1],
s[:i]+string(r))
// lexBool consumes a bool string: 'true' or 'false.
func lexBool(lx *lexer) stateFn {
var rs []rune
for {
r := lx.next()
if !unicode.IsLetter(r) {
lx.backup()
break
}
rs = append(rs, r)
}
return nil
}
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if fn := lexConst(lx, "true"); fn != nil {
return fn
s := string(rs)
switch s {
case "true", "false":
lx.emit(itemBool)
return lx.pop()
}
lx.emit(itemBool)
return lx.pop()
}
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if fn := lexConst(lx, "false"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
return lx.errorf("expected value but found %q instead", s)
}
// lexCommentStart begins the lexing of a comment. It will emit
@ -776,7 +864,7 @@ func lexCommentStart(lx *lexer) stateFn {
}
// lexComment lexes an entire comment. It assumes that '#' has been consumed.
// It will consume *up to* the first new line character, and pass control
// It will consume *up to* the first newline character, and pass control
// back to the last state on the stack.
func lexComment(lx *lexer) stateFn {
r := lx.peek()
@ -834,13 +922,7 @@ func (itype itemType) String() string {
return "EOF"
case itemText:
return "Text"
case itemString:
return "String"
case itemRawString:
return "String"
case itemMultilineString:
return "String"
case itemRawMultilineString:
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemBool:
return "Bool"

View File

@ -2,7 +2,6 @@ package toml
import (
"fmt"
"log"
"strconv"
"strings"
"time"
@ -81,7 +80,7 @@ func (p *parser) next() item {
}
func (p *parser) bug(format string, v ...interface{}) {
log.Panicf("BUG: %s\n\n", fmt.Sprintf(format, v...))
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
}
func (p *parser) expect(typ itemType) item {
@ -179,10 +178,18 @@ func (p *parser) value(it item) (interface{}, tomlType) {
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64)
if !numUnderscoresOK(it.val) {
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseInt(val, 10, 64)
if err != nil {
// See comment below for floats describing why we make a
// distinction between a bug and a user error.
// Distinguish integer values. Normally, it'd be a bug if the lexer
// provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
@ -194,29 +201,57 @@ func (p *parser) value(it item) (interface{}, tomlType) {
}
return num, p.typeOfPrimitive(it)
case itemFloat:
num, err := strconv.ParseFloat(it.val, 64)
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicf("Invalid float %q: underscores must be "+
"surrounded by digits", it.val)
}
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicf("Invalid float %q: '.' must be followed "+
"by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseFloat(val, 64)
if err != nil {
// Distinguish float values. Normally, it'd be a bug if the lexer
// provides an invalid float, but it's possible that the float is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
//
// This is also true for integers.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.bug("Expected float value, but got '%s'.", it.val)
p.panicf("Invalid float value: %q", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val)
if err != nil {
p.panicf("Invalid RFC3339 Zulu DateTime: '%s'.", it.val)
var t time.Time
var ok bool
var err error
for _, format := range []string{
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05",
"2006-01-02",
} {
t, err = time.ParseInLocation(format, it.val, time.Local)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicf("Invalid TOML Datetime: %q.", it.val)
}
return t, p.typeOfPrimitive(it)
case itemArray:
@ -234,11 +269,75 @@ func (p *parser) value(it item) (interface{}, tomlType) {
types = append(types, typ)
}
return array, p.typeOfArray(types)
case itemInlineTableStart:
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
p.currentKey = ""
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ != itemKeyStart {
p.bug("Expected key start but instead found %q, around line %d",
it.val, p.approxLine)
}
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
// retrieve key
k := p.next()
p.approxLine = k.line
kname := p.keyString(k)
// retrieve value
p.currentKey = kname
val, typ := p.value(p.next())
// make sure we keep metadata up to date
p.setType(kname, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[kname] = val
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores.
func numUnderscoresOK(s string) bool {
accept := false
for _, r := range s {
if r == '_' {
if !accept {
return false
}
accept = false
continue
}
accept = true
}
return accept
}
// numPeriodsOK checks whether every period in s is followed by a digit.
func numPeriodsOK(s string) bool {
period := false
for _, r := range s {
if period && !isDigit(r) {
return false
}
period = r == '.'
}
return !period
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.

View File

@ -95,8 +95,8 @@ func typeFields(t reflect.Type) []field {
if sf.PkgPath != "" && !sf.Anonymous { // unexported
continue
}
name, _ := getOptions(sf.Tag.Get("toml"))
if name == "-" {
opts := getOptions(sf.Tag)
if opts.skip {
continue
}
index := make([]int, len(f.index)+1)
@ -110,8 +110,9 @@ func typeFields(t reflect.Type) []field {
}
// Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != ""
if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := opts.name != ""
name := opts.name
if name == "" {
name = sf.Name
}