podman-remote.conf enablement

add the ability for the podman remote client to use a configuration file
which describes its connections. users can now define a connection the
configuration and then call it by name like:

podman-remote -c connection1

and the destination and user will be derived from the configuration
file.  if no -c is provided, we look for a connection in the
configuration file designated as 'default'.  If the configuration file
has only one connection, it will be deemed the 'default'.

Signed-off-by: baude <bbaude@redhat.com>
This commit is contained in:
baude
2019-05-28 09:21:22 -05:00
parent 8a8db34131
commit dc7ae31171
24 changed files with 1017 additions and 332 deletions

View File

@ -36,6 +36,8 @@ type MainFlags struct {
RemoteUserName string RemoteUserName string
RemoteHost string RemoteHost string
VarlinkAddress string VarlinkAddress string
ConnectionName string
RemoteConfigFilePath string
} }
type AttachValues struct { type AttachValues struct {

View File

@ -9,6 +9,8 @@ import (
const remote = true const remote = true
func init() { func init() {
rootCmd.PersistentFlags().StringVarP(&MainGlobalOpts.ConnectionName, "connection", "c", "", "remote connection name")
rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteConfigFilePath, "remote-config-path", "", "alternate path for configuration file")
rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteUserName, "username", "", "username on the remote host") rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteUserName, "username", "", "username on the remote host")
rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteHost, "remote-host", "", "remote host") rootCmd.PersistentFlags().StringVar(&MainGlobalOpts.RemoteHost, "remote-host", "", "remote host")
// TODO maybe we allow the altering of this for bridge connections? // TODO maybe we allow the altering of this for bridge connections?

View File

@ -0,0 +1,21 @@
package remoteclientconfig
const remoteConfigFileName string = "podman-remote.conf"
// RemoteConfig describes the podman remote configuration file
type RemoteConfig struct {
Connections map[string]RemoteConnection
}
// RemoteConnection describes the attributes of a podman-remote endpoint
type RemoteConnection struct {
Destination string `toml:"destination"`
Username string `toml:"username"`
IsDefault bool `toml:"default"`
}
// GetConfigFilePath is a simple helper to export the configuration file's
// path based on arch, etc
func GetConfigFilePath() string {
return getConfigFilePath()
}

View File

@ -0,0 +1,12 @@
package remoteclientconfig
import (
"path/filepath"
"github.com/docker/docker/pkg/homedir"
)
func getConfigFilePath() string {
homeDir := homedir.Get()
return filepath.Join(homeDir, ".config", "containers", remoteConfigFileName)
}

View File

@ -0,0 +1,12 @@
package remoteclientconfig
import (
"path/filepath"
"github.com/docker/docker/pkg/homedir"
)
func getConfigFilePath() string {
homeDir := homedir.Get()
return filepath.Join(homeDir, ".config", "containers", remoteConfigFileName)
}

View File

@ -0,0 +1,12 @@
package remoteclientconfig
import (
"path/filepath"
"github.com/docker/docker/pkg/homedir"
)
func getConfigFilePath() string {
homeDir := homedir.Get()
return filepath.Join(homeDir, "AppData", "podman", remoteConfigFileName)
}

View File

@ -0,0 +1,62 @@
package remoteclientconfig
import (
"io"
"github.com/BurntSushi/toml"
"github.com/pkg/errors"
)
// ReadRemoteConfig takes an io.Reader representing the remote configuration
// file and returns a remoteconfig
func ReadRemoteConfig(reader io.Reader) (*RemoteConfig, error) {
var remoteConfig RemoteConfig
// the configuration file does not exist
if reader == nil {
return &remoteConfig, ErrNoConfigationFile
}
_, err := toml.DecodeReader(reader, &remoteConfig)
if err != nil {
return nil, err
}
// We need to validate each remote connection has fields filled out
for name, conn := range remoteConfig.Connections {
if len(conn.Destination) < 1 {
return nil, errors.Errorf("connection %s has no destination defined", name)
}
}
return &remoteConfig, err
}
// GetDefault returns the default RemoteConnection. If there is only one
// connection, we assume it is the default as well
func (r *RemoteConfig) GetDefault() (*RemoteConnection, error) {
if len(r.Connections) == 0 {
return nil, ErrNoDefinedConnections
}
for _, v := range r.Connections {
if len(r.Connections) == 1 {
// if there is only one defined connection, we assume it is
// the default whether tagged as such or not
return &v, nil
}
if v.IsDefault {
return &v, nil
}
}
return nil, ErrNoDefaultConnection
}
// GetRemoteConnection "looks up" a remote connection by name and returns it in the
// form of a RemoteConnection
func (r *RemoteConfig) GetRemoteConnection(name string) (*RemoteConnection, error) {
if len(r.Connections) == 0 {
return nil, ErrNoDefinedConnections
}
for k, v := range r.Connections {
if k == name {
return &v, nil
}
}
return nil, errors.Wrap(ErrConnectionNotFound, name)
}

View File

@ -0,0 +1,201 @@
package remoteclientconfig
import (
"io"
"reflect"
"strings"
"testing"
)
var goodConfig = `
[connections]
[connections.homer]
destination = "192.168.1.1"
username = "myuser"
default = true
[connections.bart]
destination = "foobar.com"
username = "root"
`
var noDest = `
[connections]
[connections.homer]
destination = "192.168.1.1"
username = "myuser"
default = true
[connections.bart]
username = "root"
`
var noUser = `
[connections]
[connections.homer]
destination = "192.168.1.1"
`
func makeGoodResult() *RemoteConfig {
var goodConnections = make(map[string]RemoteConnection)
goodConnections["homer"] = RemoteConnection{
Destination: "192.168.1.1",
Username: "myuser",
IsDefault: true,
}
goodConnections["bart"] = RemoteConnection{
Destination: "foobar.com",
Username: "root",
}
var goodResult = RemoteConfig{
Connections: goodConnections,
}
return &goodResult
}
func makeNoUserResult() *RemoteConfig {
var goodConnections = make(map[string]RemoteConnection)
goodConnections["homer"] = RemoteConnection{
Destination: "192.168.1.1",
}
var goodResult = RemoteConfig{
Connections: goodConnections,
}
return &goodResult
}
func TestReadRemoteConfig(t *testing.T) {
type args struct {
reader io.Reader
}
tests := []struct {
name string
args args
want *RemoteConfig
wantErr bool
}{
// good test should pass
{"good", args{reader: strings.NewReader(goodConfig)}, makeGoodResult(), false},
// a connection with no destination is an error
{"nodest", args{reader: strings.NewReader(noDest)}, nil, true},
// a connnection with no user is OK
{"nouser", args{reader: strings.NewReader(noUser)}, makeNoUserResult(), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ReadRemoteConfig(tt.args.reader)
if (err != nil) != tt.wantErr {
t.Errorf("ReadRemoteConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ReadRemoteConfig() = %v, want %v", got, tt.want)
}
})
}
}
func TestRemoteConfig_GetDefault(t *testing.T) {
good := make(map[string]RemoteConnection)
good["homer"] = RemoteConnection{
Username: "myuser",
Destination: "192.168.1.1",
IsDefault: true,
}
good["bart"] = RemoteConnection{
Username: "root",
Destination: "foobar.com",
}
noDefault := make(map[string]RemoteConnection)
noDefault["homer"] = RemoteConnection{
Username: "myuser",
Destination: "192.168.1.1",
}
noDefault["bart"] = RemoteConnection{
Username: "root",
Destination: "foobar.com",
}
single := make(map[string]RemoteConnection)
single["homer"] = RemoteConnection{
Username: "myuser",
Destination: "192.168.1.1",
}
none := make(map[string]RemoteConnection)
type fields struct {
Connections map[string]RemoteConnection
}
tests := []struct {
name string
fields fields
want *RemoteConnection
wantErr bool
}{
// A good toml should return the connection that is marked isDefault
{"good", fields{Connections: makeGoodResult().Connections}, &RemoteConnection{"192.168.1.1", "myuser", true}, false},
// If nothing is marked as isDefault and there is more than one connection, error should occur
{"nodefault", fields{Connections: noDefault}, nil, true},
// if nothing is marked as isDefault but there is only one connection, the one connection is considered the default
{"single", fields{Connections: none}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &RemoteConfig{
Connections: tt.fields.Connections,
}
got, err := r.GetDefault()
if (err != nil) != tt.wantErr {
t.Errorf("RemoteConfig.GetDefault() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RemoteConfig.GetDefault() = %v, want %v", got, tt.want)
}
})
}
}
func TestRemoteConfig_GetRemoteConnection(t *testing.T) {
type fields struct {
Connections map[string]RemoteConnection
}
type args struct {
name string
}
blank := make(map[string]RemoteConnection)
tests := []struct {
name string
fields fields
args args
want *RemoteConnection
wantErr bool
}{
// Good connection
{"goodhomer", fields{Connections: makeGoodResult().Connections}, args{name: "homer"}, &RemoteConnection{"192.168.1.1", "myuser", true}, false},
// Good connection
{"goodbart", fields{Connections: makeGoodResult().Connections}, args{name: "bart"}, &RemoteConnection{"foobar.com", "root", false}, false},
// Getting an unknown connection should result in error
{"noexist", fields{Connections: makeGoodResult().Connections}, args{name: "foobar"}, nil, true},
// Getting a connection when there are none should result in an error
{"none", fields{Connections: blank}, args{name: "foobar"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &RemoteConfig{
Connections: tt.fields.Connections,
}
got, err := r.GetRemoteConnection(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("RemoteConfig.GetRemoteConnection() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("RemoteConfig.GetRemoteConnection() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,14 @@
package remoteclientconfig
import "errors"
var (
// ErrNoDefaultConnection no default connection is defined in the podman-remote.conf file
ErrNoDefaultConnection = errors.New("no default connection is defined")
// ErrNoDefinedConnections no connections are defined in the podman-remote.conf file
ErrNoDefinedConnections = errors.New("no remote connections have been defined")
// ErrConnectionNotFound unable to lookup connection by name
ErrConnectionNotFound = errors.New("remote connection not found by name")
// ErrNoConfigationFile no config file found
ErrNoConfigationFile = errors.New("no configuration file found")
)

View File

@ -0,0 +1,47 @@
% podman-remote.conf(5)
## NAME
podman-remote.conf - configuration file for the podman remote client
## DESCRIPTION
The libpod.conf file is the default configuration file for all tools using
libpod to manage containers.
The podman-remote.conf file is the default configuration file for the podman
remote client. It is in the TOML format. It is primarily used to keep track
of the user's remote connections.
## CONNECTION OPTIONS
**destination** = ""
The hostname or IP address of the remote system
**username** = ""
The username to use when connecting to the remote system
**default** = bool
Denotes whether the connection is the default connection for the user. The default connection
is used when the user does not specify a destination or connection name to `podman`.
## EXAMPLE
The following example depicts a configuration file with two connections. One of the connections
is designated as the default connection.
```
[connections]
[connections.host1]
destination = "host1"
username = "homer"
default = true
[connections.host2]
destination = "192.168.122.133"
username = "fedora"
```
## FILES
`/$HOME/.config/containers/podman-remote.conf`, default location for the podman remote
configuration file
## HISTORY
May 2019, Originally compiled by Brent Baude<bbaude@redhat.com>

View File

@ -6,42 +6,52 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/varlink/go/varlink" "github.com/varlink/go/varlink"
) )
var remoteEndpoint *Endpoint var remoteEndpoint *Endpoint
func (r RemoteRuntime) RemoteEndpoint() (remoteEndpoint *Endpoint, err error) { func (r RemoteRuntime) RemoteEndpoint() (remoteEndpoint *Endpoint, err error) {
if remoteEndpoint == nil { remoteConfigConnections, _ := remoteclientconfig.ReadRemoteConfig(r.config)
remoteEndpoint = &Endpoint{Unknown, ""}
} else {
return remoteEndpoint, nil
}
// I'm leaving this here for now as a document of the birdge format. It can be removed later once the bridge // If the user defines an env variable for podman_varlink_bridge
// function is more flushed out. // we use that as passed.
// bridge := `ssh -T root@192.168.122.1 "/usr/bin/varlink -A '/usr/bin/podman varlink \$VARLINK_ADDRESS' bridge"` if bridge := os.Getenv("PODMAN_VARLINK_BRIDGE"); bridge != "" {
if len(r.cmd.RemoteHost) > 0 { logrus.Debug("creating a varlink bridge based on env variable")
// The user has provided a remote host endpoint remoteEndpoint, err = newBridgeConnection(bridge, nil, r.cmd.LogLevel)
// if an environment variable for podman_varlink_address is defined,
// we used that as passed
} else if address := os.Getenv("PODMAN_VARLINK_ADDRESS"); address != "" {
logrus.Debug("creating a varlink address based on env variable: %s", address)
remoteEndpoint, err = newSocketConnection(address)
// if the user provides a remote host, we use it to configure a bridge connection
} else if len(r.cmd.RemoteHost) > 0 {
logrus.Debug("creating a varlink bridge based on user input")
if len(r.cmd.RemoteUserName) < 1 { if len(r.cmd.RemoteUserName) < 1 {
return nil, errors.New("you must provide a username when providing a remote host name") return nil, errors.New("you must provide a username when providing a remote host name")
} }
remoteEndpoint.Type = BridgeConnection rc := remoteclientconfig.RemoteConnection{r.cmd.RemoteHost, r.cmd.RemoteUserName, false}
remoteEndpoint.Connection = fmt.Sprintf( remoteEndpoint, err = newBridgeConnection("", &rc, r.cmd.LogLevel)
`ssh -T %s@%s /usr/bin/varlink -A \'/usr/bin/podman --log-level=%s varlink \\\$VARLINK_ADDRESS\' bridge`, // if the user has a config file with connections in it
r.cmd.RemoteUserName, r.cmd.RemoteHost, r.cmd.LogLevel) } else if len(remoteConfigConnections.Connections) > 0 {
logrus.Debug("creating a varlink bridge based configuration file")
} else if bridge := os.Getenv("PODMAN_VARLINK_BRIDGE"); bridge != "" { var rc *remoteclientconfig.RemoteConnection
remoteEndpoint.Type = BridgeConnection if len(r.cmd.ConnectionName) > 0 {
remoteEndpoint.Connection = bridge rc, err = remoteConfigConnections.GetRemoteConnection(r.cmd.ConnectionName)
} else { } else {
address := os.Getenv("PODMAN_VARLINK_ADDRESS") rc, err = remoteConfigConnections.GetDefault()
if address == "" {
address = DefaultAddress
} }
remoteEndpoint.Type = DirectConnection if err != nil {
remoteEndpoint.Connection = address return nil, err
}
remoteEndpoint, err = newBridgeConnection("", rc, r.cmd.LogLevel)
// last resort is to make a socket connection with the default varlink address for root user
} else {
logrus.Debug("creating a varlink address based default root address")
remoteEndpoint, err = newSocketConnection(DefaultAddress)
} }
return return
} }
@ -72,3 +82,12 @@ func (r RemoteRuntime) RefreshConnection() error {
r.Conn = newConn r.Conn = newConn
return nil return nil
} }
// newSocketConnection returns an endpoint for a uds based connection
func newSocketConnection(address string) (*Endpoint, error) {
endpoint := Endpoint{
Type: DirectConnection,
Connection: address,
}
return &endpoint, nil
}

View File

@ -0,0 +1,30 @@
// +build linux darwin
// +build remoteclient
package adapter
import (
"fmt"
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/pkg/errors"
)
// newBridgeConnection creates a bridge type endpoint with username, destination, and log-level
func newBridgeConnection(formattedBridge string, remoteConn *remoteclientconfig.RemoteConnection, logLevel string) (*Endpoint, error) {
endpoint := Endpoint{
Type: BridgeConnection,
}
if len(formattedBridge) < 1 && remoteConn == nil {
return nil, errors.New("bridge connections must either be created by string or remoteconnection")
}
if len(formattedBridge) > 0 {
endpoint.Connection = formattedBridge
return &endpoint, nil
}
endpoint.Connection = fmt.Sprintf(
`ssh -T %s@%s -- /usr/bin/varlink -A \'/usr/bin/podman --log-level=%s varlink \\\$VARLINK_ADDRESS\' bridge`,
remoteConn.Username, remoteConn.Destination, logLevel)
return &endpoint, nil
}

View File

@ -0,0 +1,15 @@
// +build remoteclient
package adapter
import (
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/containers/libpod/libpod"
)
func newBridgeConnection(formattedBridge string, remoteConn *remoteclientconfig.RemoteConnection, logLevel string) (*Endpoint, error) {
// TODO
// Unix and Windows appear to quote their ssh implementations differently therefore once we figure out what
// windows ssh is doing here, we can then get the format correct.
return nil, libpod.ErrNotImplemented
}

View File

@ -20,6 +20,7 @@ import (
"github.com/containers/image/docker/reference" "github.com/containers/image/docker/reference"
"github.com/containers/image/types" "github.com/containers/image/types"
"github.com/containers/libpod/cmd/podman/cliconfig" "github.com/containers/libpod/cmd/podman/cliconfig"
"github.com/containers/libpod/cmd/podman/remoteclientconfig"
"github.com/containers/libpod/cmd/podman/varlink" "github.com/containers/libpod/cmd/podman/varlink"
"github.com/containers/libpod/libpod" "github.com/containers/libpod/libpod"
"github.com/containers/libpod/libpod/events" "github.com/containers/libpod/libpod/events"
@ -40,6 +41,7 @@ type RemoteRuntime struct {
Conn *varlink.Connection Conn *varlink.Connection
Remote bool Remote bool
cmd cliconfig.MainFlags cmd cliconfig.MainFlags
config io.Reader
} }
// LocalRuntime describes a typical libpod runtime // LocalRuntime describes a typical libpod runtime
@ -49,10 +51,35 @@ type LocalRuntime struct {
// GetRuntime returns a LocalRuntime struct with the actual runtime embedded in it // GetRuntime returns a LocalRuntime struct with the actual runtime embedded in it
func GetRuntime(ctx context.Context, c *cliconfig.PodmanCommand) (*LocalRuntime, error) { func GetRuntime(ctx context.Context, c *cliconfig.PodmanCommand) (*LocalRuntime, error) {
var (
customConfig bool
err error
f *os.File
)
runtime := RemoteRuntime{ runtime := RemoteRuntime{
Remote: true, Remote: true,
cmd: c.GlobalFlags, cmd: c.GlobalFlags,
} }
configPath := remoteclientconfig.GetConfigFilePath()
if len(c.GlobalFlags.RemoteConfigFilePath) > 0 {
configPath = c.GlobalFlags.RemoteConfigFilePath
customConfig = true
}
f, err = os.Open(configPath)
if err != nil {
// If user does not explicitly provide a configuration file path and we cannot
// find a default, no error should occur.
if os.IsNotExist(err) && !customConfig {
logrus.Debugf("unable to load configuration file at %s", configPath)
runtime.config = nil
} else {
return nil, errors.Wrapf(err, "unable to load configuration file at %s", configPath)
}
} else {
// create the io reader for the remote client
runtime.config = bufio.NewReader(f)
}
conn, err := runtime.Connect() conn, err := runtime.Connect()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -3,7 +3,7 @@
# #
# TODO: no release, can we find an alternative? # TODO: no release, can we find an alternative?
github.com/Azure/go-ansiterm d6e3b3328b783f23731bc4d058875b0371ff8109 github.com/Azure/go-ansiterm d6e3b3328b783f23731bc4d058875b0371ff8109
github.com/BurntSushi/toml v0.2.0 github.com/BurntSushi/toml v0.3.1
github.com/Microsoft/go-winio v0.4.11 github.com/Microsoft/go-winio v0.4.11
github.com/Microsoft/hcsshim v0.8.3 github.com/Microsoft/hcsshim v0.8.3
github.com/blang/semver v3.5.0 github.com/blang/semver v3.5.0

View File

@ -1,14 +1,21 @@
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE The MIT License (MIT)
Version 2, December 2004
Copyright (C) 2004 Sam Hocevar <sam@hocevar.net> Copyright (c) 2013 TOML authors
Everyone is permitted to copy and distribute verbatim or modified Permission is hereby granted, free of charge, to any person obtaining a copy
copies of this license document, and changing it is allowed as long of this software and associated documentation files (the "Software"), to deal
as the name is changed. in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE The above copyright notice and this permission notice shall be included in
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION all copies or substantial portions of the Software.
0. You just DO WHAT THE FUCK YOU WANT TO.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -6,12 +6,12 @@ packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data `encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.) representations. (There is an example of this below.)
Spec: https://github.com/mojombo/toml Spec: https://github.com/toml-lang/toml
Compatible with TOML version Compatible with TOML version
[v0.2.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.2.0.md) [v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Documentation: http://godoc.org/github.com/BurntSushi/toml Documentation: https://godoc.org/github.com/BurntSushi/toml
Installation: Installation:
@ -26,8 +26,7 @@ go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml tomlv some-toml-file.toml
``` ```
[![Build status](https://api.travis-ci.org/BurntSushi/toml.png)](https://travis-ci.org/BurntSushi/toml) [![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml)
### Testing ### Testing
@ -217,4 +216,3 @@ Note that a case insensitive match will be tried if an exact match can't be
found. found.
A working example of the above can be found in `_examples/example.{go,toml}`. A working example of the above can be found in `_examples/example.{go,toml}`.

View File

@ -10,7 +10,9 @@ import (
"time" "time"
) )
var e = fmt.Errorf func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
// Unmarshaler is the interface implemented by objects that can unmarshal a // Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves. // TOML description of themselves.
@ -103,6 +105,13 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
// This decoder will not handle cyclic types. If a cyclic type is passed, // This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate. // `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) { func Decode(data string, v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
}
p, err := parse(data) p, err := parse(data)
if err != nil { if err != nil {
return MetaData{}, err return MetaData{}, err
@ -111,7 +120,7 @@ func Decode(data string, v interface{}) (MetaData, error) {
p.mapping, p.types, p.ordered, p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil, make(map[string]bool, len(p.ordered)), nil,
} }
return md, md.unify(p.mapping, rvalue(v)) return md, md.unify(p.mapping, indirect(rv))
} }
// DecodeFile is just like Decode, except it will automatically read the // DecodeFile is just like Decode, except it will automatically read the
@ -211,7 +220,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Interface: case reflect.Interface:
// we only support empty interfaces. // we only support empty interfaces.
if rv.NumMethod() > 0 { if rv.NumMethod() > 0 {
return e("Unsupported type '%s'.", rv.Kind()) return e("unsupported type %s", rv.Type())
} }
return md.unifyAnything(data, rv) return md.unifyAnything(data, rv)
case reflect.Float32: case reflect.Float32:
@ -219,7 +228,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
case reflect.Float64: case reflect.Float64:
return md.unifyFloat64(data, rv) return md.unifyFloat64(data, rv)
} }
return e("Unsupported type '%s'.", rv.Kind()) return e("unsupported type %s", rv.Kind())
} }
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
@ -228,7 +237,8 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
if mapping == nil { if mapping == nil {
return nil return nil
} }
return mismatch(rv, "map", mapping) return e("type mismatch for %s: expected table but found %T",
rv.Type().String(), mapping)
} }
for key, datum := range tmap { for key, datum := range tmap {
@ -253,14 +263,13 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
md.decoded[md.context.add(key).String()] = true md.decoded[md.context.add(key).String()] = true
md.context = append(md.context, key) md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil { if err := md.unify(datum, subv); err != nil {
return e("Type mismatch for '%s.%s': %s", return err
rv.Type().String(), f.name, err)
} }
md.context = md.context[0 : len(md.context)-1] md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" { } else if f.name != "" {
// Bad user! No soup for you! // Bad user! No soup for you!
return e("Field '%s.%s' is unexported, and therefore cannot "+ return e("cannot write unexported field %s.%s",
"be loaded with reflection.", rv.Type().String(), f.name) rv.Type().String(), f.name)
} }
} }
} }
@ -378,15 +387,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
// No bounds checking necessary. // No bounds checking necessary.
case reflect.Int8: case reflect.Int8:
if num < math.MinInt8 || num > math.MaxInt8 { if num < math.MinInt8 || num > math.MaxInt8 {
return e("Value '%d' is out of range for int8.", num) return e("value %d is out of range for int8", num)
} }
case reflect.Int16: case reflect.Int16:
if num < math.MinInt16 || num > math.MaxInt16 { if num < math.MinInt16 || num > math.MaxInt16 {
return e("Value '%d' is out of range for int16.", num) return e("value %d is out of range for int16", num)
} }
case reflect.Int32: case reflect.Int32:
if num < math.MinInt32 || num > math.MaxInt32 { if num < math.MinInt32 || num > math.MaxInt32 {
return e("Value '%d' is out of range for int32.", num) return e("value %d is out of range for int32", num)
} }
} }
rv.SetInt(num) rv.SetInt(num)
@ -397,15 +406,15 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
// No bounds checking necessary. // No bounds checking necessary.
case reflect.Uint8: case reflect.Uint8:
if num < 0 || unum > math.MaxUint8 { if num < 0 || unum > math.MaxUint8 {
return e("Value '%d' is out of range for uint8.", num) return e("value %d is out of range for uint8", num)
} }
case reflect.Uint16: case reflect.Uint16:
if num < 0 || unum > math.MaxUint16 { if num < 0 || unum > math.MaxUint16 {
return e("Value '%d' is out of range for uint16.", num) return e("value %d is out of range for uint16", num)
} }
case reflect.Uint32: case reflect.Uint32:
if num < 0 || unum > math.MaxUint32 { if num < 0 || unum > math.MaxUint32 {
return e("Value '%d' is out of range for uint32.", num) return e("value %d is out of range for uint32", num)
} }
} }
rv.SetUint(unum) rv.SetUint(unum)
@ -471,7 +480,7 @@ func rvalue(v interface{}) reflect.Value {
// interest to us (like encoding.TextUnmarshaler). // interest to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value { func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Ptr {
if v.CanAddr() { if v.CanSet() {
pv := v.Addr() pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok { if _, ok := pv.Interface().(TextUnmarshaler); ok {
return pv return pv
@ -496,10 +505,5 @@ func isUnifiable(rv reflect.Value) bool {
} }
func badtype(expected string, data interface{}) error { func badtype(expected string, data interface{}) error {
return e("Expected %s but found '%T'.", expected, data) return e("cannot load TOML value of type %T into a Go %s", data, expected)
}
func mismatch(user reflect.Value, expected string, data interface{}) error {
return e("Type mismatch for %s. Expected %s but found '%T'.",
user.Type().String(), expected, data)
} }

View File

@ -77,9 +77,8 @@ func (k Key) maybeQuoted(i int) string {
} }
if quote { if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\"" return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
} else {
return k[i]
} }
return k[i]
} }
func (k Key) add(piece string) Key { func (k Key) add(piece string) Key {

View File

@ -4,7 +4,7 @@ files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the the Primitive type, and querying the set of keys in a TOML document with the
MetaData type. MetaData type.
The specification implemented: https://github.com/mojombo/toml The specification implemented: https://github.com/toml-lang/toml
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the whether a file is a valid TOML document. It can also be used to print the

View File

@ -16,17 +16,17 @@ type tomlEncodeError struct{ error }
var ( var (
errArrayMixedElementTypes = errors.New( errArrayMixedElementTypes = errors.New(
"can't encode array with mixed element types") "toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New( errArrayNilElement = errors.New(
"can't encode array with nil element") "toml: cannot encode array with nil element")
errNonString = errors.New( errNonString = errors.New(
"can't encode a map with non-string key type") "toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New( errAnonNonStruct = errors.New(
"can't encode an anonymous field that is not a struct") "toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New( errArrayNoTable = errors.New(
"TOML array element can't contain a table") "toml: TOML array element cannot contain a table")
errNoKey = errors.New( errNoKey = errors.New(
"top-level values must be a Go map or struct") "toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing errAnything = errors.New("") // used in testing
) )
@ -148,7 +148,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
case reflect.Struct: case reflect.Struct:
enc.eTable(key, rv) enc.eTable(key, rv)
default: default:
panic(e("Unsupported type for key '%s': %s", key, k)) panic(e("unsupported type for key '%s': %s", key, k))
} }
} }
@ -160,7 +160,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
// Special case time.Time as a primitive. Has to come before // Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements // TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC. // encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.In(time.FixedZone("UTC", 0)).Format("2006-01-02T15:04:05Z")) enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
return return
case TextMarshaler: case TextMarshaler:
// Special case. Use text marshaler if it's available for this value. // Special case. Use text marshaler if it's available for this value.
@ -191,7 +191,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
case reflect.String: case reflect.String:
enc.writeQuoted(rv.String()) enc.writeQuoted(rv.String())
default: default:
panic(e("Unexpected primitive type: %s", rv.Kind())) panic(e("unexpected primitive type: %s", rv.Kind()))
} }
} }
@ -315,10 +315,16 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
t := f.Type t := f.Type
switch t.Kind() { switch t.Kind() {
case reflect.Struct: case reflect.Struct:
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, f.Index) addFields(t, frv, f.Index)
continue continue
}
case reflect.Ptr: case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct { if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if !frv.IsNil() { if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), f.Index) addFields(t.Elem(), frv.Elem(), f.Index)
} }
@ -347,17 +353,18 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
continue continue
} }
tag := sft.Tag.Get("toml") opts := getOptions(sft.Tag)
if tag == "-" { if opts.skip {
continue continue
} }
keyName, opts := getOptions(tag) keyName := sft.Name
if keyName == "" { if opts.name != "" {
keyName = sft.Name keyName = opts.name
} }
if _, ok := opts["omitempty"]; ok && isEmpty(sf) { if opts.omitempty && isEmpty(sf) {
continue continue
} else if _, ok := opts["omitzero"]; ok && isZero(sf) { }
if opts.omitzero && isZero(sf) {
continue continue
} }
@ -392,9 +399,8 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if typeEqual(tomlHash, tomlArrayType(rv)) { if typeEqual(tomlHash, tomlArrayType(rv)) {
return tomlArrayHash return tomlArrayHash
} else {
return tomlArray
} }
return tomlArray
case reflect.Ptr, reflect.Interface: case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem()) return tomlTypeOfGo(rv.Elem())
case reflect.String: case reflect.String:
@ -451,17 +457,30 @@ func tomlArrayType(rv reflect.Value) tomlType {
return firstType return firstType
} }
func getOptions(keyName string) (string, map[string]struct{}) { type tagOptions struct {
opts := make(map[string]struct{}) skip bool // "-"
ss := strings.Split(keyName, ",") name string
name := ss[0] omitempty bool
if len(ss) > 1 { omitzero bool
for _, opt := range ss {
opts[opt] = struct{}{}
}
} }
return name, opts func getOptions(tag reflect.StructTag) tagOptions {
t := tag.Get("toml")
if t == "-" {
return tagOptions{skip: true}
}
var opts tagOptions
parts := strings.Split(t, ",")
opts.name = parts[0]
for _, s := range parts[1:] {
switch s {
case "omitempty":
opts.omitempty = true
case "omitzero":
opts.omitzero = true
}
}
return opts
} }
func isZero(rv reflect.Value) bool { func isZero(rv reflect.Value) bool {

View File

@ -3,6 +3,7 @@ package toml
import ( import (
"fmt" "fmt"
"strings" "strings"
"unicode"
"unicode/utf8" "unicode/utf8"
) )
@ -29,10 +30,13 @@ const (
itemArrayTableEnd itemArrayTableEnd
itemKeyStart itemKeyStart
itemCommentStart itemCommentStart
itemInlineTableStart
itemInlineTableEnd
) )
const ( const (
eof = 0 eof = 0
comma = ','
tableStart = '[' tableStart = '['
tableEnd = ']' tableEnd = ']'
arrayTableStart = '[' arrayTableStart = '['
@ -41,12 +45,13 @@ const (
keySep = '=' keySep = '='
arrayStart = '[' arrayStart = '['
arrayEnd = ']' arrayEnd = ']'
arrayValTerm = ','
commentStart = '#' commentStart = '#'
stringStart = '"' stringStart = '"'
stringEnd = '"' stringEnd = '"'
rawStringStart = '\'' rawStringStart = '\''
rawStringEnd = '\'' rawStringEnd = '\''
inlineTableStart = '{'
inlineTableEnd = '}'
) )
type stateFn func(lx *lexer) stateFn type stateFn func(lx *lexer) stateFn
@ -55,11 +60,18 @@ type lexer struct {
input string input string
start int start int
pos int pos int
width int
line int line int
state stateFn state stateFn
items chan item items chan item
// Allow for backing up up to three runes.
// This is necessary because TOML contains 3-rune tokens (""" and ''').
prevWidths [3]int
nprev int // how many of prevWidths are in use
// If we emit an eof, we can still back up, but it is not OK to call
// next again.
atEOF bool
// A stack of state functions used to maintain context. // A stack of state functions used to maintain context.
// The idea is to reuse parts of the state machine in various places. // The idea is to reuse parts of the state machine in various places.
// For example, values can appear at the top level or within arbitrarily // For example, values can appear at the top level or within arbitrarily
@ -87,7 +99,7 @@ func (lx *lexer) nextItem() item {
func lex(input string) *lexer { func lex(input string) *lexer {
lx := &lexer{ lx := &lexer{
input: input + "\n", input: input,
state: lexTop, state: lexTop,
line: 1, line: 1,
items: make(chan item, 10), items: make(chan item, 10),
@ -102,7 +114,7 @@ func (lx *lexer) push(state stateFn) {
func (lx *lexer) pop() stateFn { func (lx *lexer) pop() stateFn {
if len(lx.stack) == 0 { if len(lx.stack) == 0 {
return lx.errorf("BUG in lexer: no states to pop.") return lx.errorf("BUG in lexer: no states to pop")
} }
last := lx.stack[len(lx.stack)-1] last := lx.stack[len(lx.stack)-1]
lx.stack = lx.stack[0 : len(lx.stack)-1] lx.stack = lx.stack[0 : len(lx.stack)-1]
@ -124,16 +136,25 @@ func (lx *lexer) emitTrim(typ itemType) {
} }
func (lx *lexer) next() (r rune) { func (lx *lexer) next() (r rune) {
if lx.atEOF {
panic("next called after EOF")
}
if lx.pos >= len(lx.input) { if lx.pos >= len(lx.input) {
lx.width = 0 lx.atEOF = true
return eof return eof
} }
if lx.input[lx.pos] == '\n' { if lx.input[lx.pos] == '\n' {
lx.line++ lx.line++
} }
r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:]) lx.prevWidths[2] = lx.prevWidths[1]
lx.pos += lx.width lx.prevWidths[1] = lx.prevWidths[0]
if lx.nprev < 3 {
lx.nprev++
}
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
lx.prevWidths[0] = w
lx.pos += w
return r return r
} }
@ -142,9 +163,20 @@ func (lx *lexer) ignore() {
lx.start = lx.pos lx.start = lx.pos
} }
// backup steps back one rune. Can be called only once per call of next. // backup steps back one rune. Can be called only twice between calls to next.
func (lx *lexer) backup() { func (lx *lexer) backup() {
lx.pos -= lx.width if lx.atEOF {
lx.atEOF = false
return
}
if lx.nprev < 1 {
panic("backed up too far")
}
w := lx.prevWidths[0]
lx.prevWidths[0] = lx.prevWidths[1]
lx.prevWidths[1] = lx.prevWidths[2]
lx.nprev--
lx.pos -= w
if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' { if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' {
lx.line-- lx.line--
} }
@ -166,6 +198,19 @@ func (lx *lexer) peek() rune {
return r return r
} }
// skip ignores all input that matches the given predicate.
func (lx *lexer) skip(pred func(rune) bool) {
for {
r := lx.next()
if pred(r) {
continue
}
lx.backup()
lx.ignore()
return
}
}
// errorf stops all lexing by emitting an error and returning `nil`. // errorf stops all lexing by emitting an error and returning `nil`.
// Note that any value that is a character is escaped if it's a special // Note that any value that is a character is escaped if it's a special
// character (newlines, tabs, etc.). // character (newlines, tabs, etc.).
@ -184,7 +229,6 @@ func lexTop(lx *lexer) stateFn {
if isWhitespace(r) || isNL(r) { if isWhitespace(r) || isNL(r) {
return lexSkip(lx, lexTop) return lexSkip(lx, lexTop)
} }
switch r { switch r {
case commentStart: case commentStart:
lx.push(lexTop) lx.push(lexTop)
@ -193,7 +237,7 @@ func lexTop(lx *lexer) stateFn {
return lexTableStart return lexTableStart
case eof: case eof:
if lx.pos > lx.start { if lx.pos > lx.start {
return lx.errorf("Unexpected EOF.") return lx.errorf("unexpected EOF")
} }
lx.emit(itemEOF) lx.emit(itemEOF)
return nil return nil
@ -222,11 +266,11 @@ func lexTopEnd(lx *lexer) stateFn {
lx.ignore() lx.ignore()
return lexTop return lexTop
case r == eof: case r == eof:
lx.ignore() lx.emit(itemEOF)
return lexTop return nil
} }
return lx.errorf("Expected a top-level item to end with a new line, "+ return lx.errorf("expected a top-level item to end with a newline, "+
"comment or EOF, but got %q instead.", r) "comment, or EOF, but got %q instead", r)
} }
// lexTable lexes the beginning of a table. Namely, it makes sure that // lexTable lexes the beginning of a table. Namely, it makes sure that
@ -253,21 +297,22 @@ func lexTableEnd(lx *lexer) stateFn {
func lexArrayTableEnd(lx *lexer) stateFn { func lexArrayTableEnd(lx *lexer) stateFn {
if r := lx.next(); r != arrayTableEnd { if r := lx.next(); r != arrayTableEnd {
return lx.errorf("Expected end of table array name delimiter %q, "+ return lx.errorf("expected end of table array name delimiter %q, "+
"but got %q instead.", arrayTableEnd, r) "but got %q instead", arrayTableEnd, r)
} }
lx.emit(itemArrayTableEnd) lx.emit(itemArrayTableEnd)
return lexTopEnd return lexTopEnd
} }
func lexTableNameStart(lx *lexer) stateFn { func lexTableNameStart(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.peek(); { switch r := lx.peek(); {
case r == tableEnd || r == eof: case r == tableEnd || r == eof:
return lx.errorf("Unexpected end of table name. (Table names cannot " + return lx.errorf("unexpected end of table name " +
"be empty.)") "(table names cannot be empty)")
case r == tableSep: case r == tableSep:
return lx.errorf("Unexpected table separator. (Table names cannot " + return lx.errorf("unexpected table separator " +
"be empty.)") "(table names cannot be empty)")
case r == stringStart || r == rawStringStart: case r == stringStart || r == rawStringStart:
lx.ignore() lx.ignore()
lx.push(lexTableNameEnd) lx.push(lexTableNameEnd)
@ -277,24 +322,22 @@ func lexTableNameStart(lx *lexer) stateFn {
} }
} }
// lexTableName lexes the name of a table. It assumes that at least one // lexBareTableName lexes the name of a table. It assumes that at least one
// valid character for the table has already been read. // valid character for the table has already been read.
func lexBareTableName(lx *lexer) stateFn { func lexBareTableName(lx *lexer) stateFn {
switch r := lx.next(); { r := lx.next()
case isBareKeyChar(r): if isBareKeyChar(r) {
return lexBareTableName return lexBareTableName
case r == tableSep || r == tableEnd:
lx.backup()
lx.emitTrim(itemText)
return lexTableNameEnd
default:
return lx.errorf("Bare keys cannot contain %q.", r)
} }
lx.backup()
lx.emit(itemText)
return lexTableNameEnd
} }
// lexTableNameEnd reads the end of a piece of a table name, optionally // lexTableNameEnd reads the end of a piece of a table name, optionally
// consuming whitespace. // consuming whitespace.
func lexTableNameEnd(lx *lexer) stateFn { func lexTableNameEnd(lx *lexer) stateFn {
lx.skip(isWhitespace)
switch r := lx.next(); { switch r := lx.next(); {
case isWhitespace(r): case isWhitespace(r):
return lexTableNameEnd return lexTableNameEnd
@ -304,8 +347,8 @@ func lexTableNameEnd(lx *lexer) stateFn {
case r == tableEnd: case r == tableEnd:
return lx.pop() return lx.pop()
default: default:
return lx.errorf("Expected '.' or ']' to end table name, but got %q "+ return lx.errorf("expected '.' or ']' to end table name, "+
"instead.", r) "but got %q instead", r)
} }
} }
@ -315,7 +358,7 @@ func lexKeyStart(lx *lexer) stateFn {
r := lx.peek() r := lx.peek()
switch { switch {
case r == keySep: case r == keySep:
return lx.errorf("Unexpected key separator %q.", keySep) return lx.errorf("unexpected key separator %q", keySep)
case isWhitespace(r) || isNL(r): case isWhitespace(r) || isNL(r):
lx.next() lx.next()
return lexSkip(lx, lexKeyStart) return lexSkip(lx, lexKeyStart)
@ -338,14 +381,15 @@ func lexBareKey(lx *lexer) stateFn {
case isBareKeyChar(r): case isBareKeyChar(r):
return lexBareKey return lexBareKey
case isWhitespace(r): case isWhitespace(r):
lx.emitTrim(itemText) lx.backup()
lx.emit(itemText)
return lexKeyEnd return lexKeyEnd
case r == keySep: case r == keySep:
lx.backup() lx.backup()
lx.emitTrim(itemText) lx.emit(itemText)
return lexKeyEnd return lexKeyEnd
default: default:
return lx.errorf("Bare keys cannot contain %q.", r) return lx.errorf("bare keys cannot contain %q", r)
} }
} }
@ -358,7 +402,7 @@ func lexKeyEnd(lx *lexer) stateFn {
case isWhitespace(r): case isWhitespace(r):
return lexSkip(lx, lexKeyEnd) return lexSkip(lx, lexKeyEnd)
default: default:
return lx.errorf("Expected key separator %q, but got %q instead.", return lx.errorf("expected key separator %q, but got %q instead",
keySep, r) keySep, r)
} }
} }
@ -368,19 +412,25 @@ func lexKeyEnd(lx *lexer) stateFn {
// After a value is lexed, the last state on the next is popped and returned. // After a value is lexed, the last state on the next is popped and returned.
func lexValue(lx *lexer) stateFn { func lexValue(lx *lexer) stateFn {
// We allow whitespace to precede a value, but NOT newlines. // We allow whitespace to precede a value, but NOT newlines.
// In array syntax, the array states are responsible for ignoring new // In array syntax, the array states are responsible for ignoring newlines.
// lines.
r := lx.next() r := lx.next()
if isWhitespace(r) {
return lexSkip(lx, lexValue)
}
switch { switch {
case r == arrayStart: case isWhitespace(r):
return lexSkip(lx, lexValue)
case isDigit(r):
lx.backup() // avoid an extra state and use the same as above
return lexNumberOrDateStart
}
switch r {
case arrayStart:
lx.ignore() lx.ignore()
lx.emit(itemArray) lx.emit(itemArray)
return lexArrayValue return lexArrayValue
case r == stringStart: case inlineTableStart:
lx.ignore()
lx.emit(itemInlineTableStart)
return lexInlineTableValue
case stringStart:
if lx.accept(stringStart) { if lx.accept(stringStart) {
if lx.accept(stringStart) { if lx.accept(stringStart) {
lx.ignore() // Ignore """ lx.ignore() // Ignore """
@ -390,7 +440,7 @@ func lexValue(lx *lexer) stateFn {
} }
lx.ignore() // ignore the '"' lx.ignore() // ignore the '"'
return lexString return lexString
case r == rawStringStart: case rawStringStart:
if lx.accept(rawStringStart) { if lx.accept(rawStringStart) {
if lx.accept(rawStringStart) { if lx.accept(rawStringStart) {
lx.ignore() // Ignore """ lx.ignore() // Ignore """
@ -400,19 +450,20 @@ func lexValue(lx *lexer) stateFn {
} }
lx.ignore() // ignore the "'" lx.ignore() // ignore the "'"
return lexRawString return lexRawString
case r == 't': case '+', '-':
return lexTrue
case r == 'f':
return lexFalse
case r == '-':
return lexNumberStart return lexNumberStart
case isDigit(r): case '.': // special error case, be kind to users
lx.backup() // avoid an extra state and use the same as above return lx.errorf("floats must start with a digit, not '.'")
return lexNumberOrDateStart
case r == '.': // special error case, be kind to users
return lx.errorf("Floats must start with a digit, not '.'.")
} }
return lx.errorf("Expected value but found %q instead.", r) if unicode.IsLetter(r) {
// Be permissive here; lexBool will give a nice error if the
// user wrote something like
// x = foo
// (i.e. not 'true' or 'false' but is something else word-like.)
lx.backup()
return lexBool
}
return lx.errorf("expected value but found %q instead", r)
} }
// lexArrayValue consumes one value in an array. It assumes that '[' or ',' // lexArrayValue consumes one value in an array. It assumes that '[' or ','
@ -425,10 +476,11 @@ func lexArrayValue(lx *lexer) stateFn {
case r == commentStart: case r == commentStart:
lx.push(lexArrayValue) lx.push(lexArrayValue)
return lexCommentStart return lexCommentStart
case r == arrayValTerm: case r == comma:
return lx.errorf("Unexpected array value terminator %q.", return lx.errorf("unexpected comma")
arrayValTerm)
case r == arrayEnd: case r == arrayEnd:
// NOTE(caleb): The spec isn't clear about whether you can have
// a trailing comma or not, so we'll allow it.
return lexArrayEnd return lexArrayEnd
} }
@ -437,8 +489,9 @@ func lexArrayValue(lx *lexer) stateFn {
return lexValue return lexValue
} }
// lexArrayValueEnd consumes the cruft between values of an array. Namely, // lexArrayValueEnd consumes everything between the end of an array value and
// it ignores whitespace and expects either a ',' or a ']'. // the next value (or the end of the array): it ignores whitespace and newlines
// and expects either a ',' or a ']'.
func lexArrayValueEnd(lx *lexer) stateFn { func lexArrayValueEnd(lx *lexer) stateFn {
r := lx.next() r := lx.next()
switch { switch {
@ -447,31 +500,88 @@ func lexArrayValueEnd(lx *lexer) stateFn {
case r == commentStart: case r == commentStart:
lx.push(lexArrayValueEnd) lx.push(lexArrayValueEnd)
return lexCommentStart return lexCommentStart
case r == arrayValTerm: case r == comma:
lx.ignore() lx.ignore()
return lexArrayValue // move on to the next value return lexArrayValue // move on to the next value
case r == arrayEnd: case r == arrayEnd:
return lexArrayEnd return lexArrayEnd
} }
return lx.errorf("Expected an array value terminator %q or an array "+ return lx.errorf(
"terminator %q, but got %q instead.", arrayValTerm, arrayEnd, r) "expected a comma or array terminator %q, but got %q instead",
arrayEnd, r,
)
} }
// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has // lexArrayEnd finishes the lexing of an array.
// just been consumed. // It assumes that a ']' has just been consumed.
func lexArrayEnd(lx *lexer) stateFn { func lexArrayEnd(lx *lexer) stateFn {
lx.ignore() lx.ignore()
lx.emit(itemArrayEnd) lx.emit(itemArrayEnd)
return lx.pop() return lx.pop()
} }
// lexInlineTableValue consumes one key/value pair in an inline table.
// It assumes that '{' or ',' have already been consumed. Whitespace is ignored.
func lexInlineTableValue(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValue)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValue)
return lexCommentStart
case r == comma:
return lx.errorf("unexpected comma")
case r == inlineTableEnd:
return lexInlineTableEnd
}
lx.backup()
lx.push(lexInlineTableValueEnd)
return lexKeyStart
}
// lexInlineTableValueEnd consumes everything between the end of an inline table
// key/value pair and the next pair (or the end of the table):
// it ignores whitespace and expects either a ',' or a '}'.
func lexInlineTableValueEnd(lx *lexer) stateFn {
r := lx.next()
switch {
case isWhitespace(r):
return lexSkip(lx, lexInlineTableValueEnd)
case isNL(r):
return lx.errorf("newlines not allowed within inline tables")
case r == commentStart:
lx.push(lexInlineTableValueEnd)
return lexCommentStart
case r == comma:
lx.ignore()
return lexInlineTableValue
case r == inlineTableEnd:
return lexInlineTableEnd
}
return lx.errorf("expected a comma or an inline table terminator %q, "+
"but got %q instead", inlineTableEnd, r)
}
// lexInlineTableEnd finishes the lexing of an inline table.
// It assumes that a '}' has just been consumed.
func lexInlineTableEnd(lx *lexer) stateFn {
lx.ignore()
lx.emit(itemInlineTableEnd)
return lx.pop()
}
// lexString consumes the inner contents of a string. It assumes that the // lexString consumes the inner contents of a string. It assumes that the
// beginning '"' has already been consumed and ignored. // beginning '"' has already been consumed and ignored.
func lexString(lx *lexer) stateFn { func lexString(lx *lexer) stateFn {
r := lx.next() r := lx.next()
switch { switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r): case isNL(r):
return lx.errorf("Strings cannot contain new lines.") return lx.errorf("strings cannot contain newlines")
case r == '\\': case r == '\\':
lx.push(lexString) lx.push(lexString)
return lexStringEscape return lexStringEscape
@ -488,11 +598,12 @@ func lexString(lx *lexer) stateFn {
// lexMultilineString consumes the inner contents of a string. It assumes that // lexMultilineString consumes the inner contents of a string. It assumes that
// the beginning '"""' has already been consumed and ignored. // the beginning '"""' has already been consumed and ignored.
func lexMultilineString(lx *lexer) stateFn { func lexMultilineString(lx *lexer) stateFn {
r := lx.next() switch lx.next() {
switch { case eof:
case r == '\\': return lx.errorf("unexpected EOF")
case '\\':
return lexMultilineStringEscape return lexMultilineStringEscape
case r == stringEnd: case stringEnd:
if lx.accept(stringEnd) { if lx.accept(stringEnd) {
if lx.accept(stringEnd) { if lx.accept(stringEnd) {
lx.backup() lx.backup()
@ -516,8 +627,10 @@ func lexMultilineString(lx *lexer) stateFn {
func lexRawString(lx *lexer) stateFn { func lexRawString(lx *lexer) stateFn {
r := lx.next() r := lx.next()
switch { switch {
case r == eof:
return lx.errorf("unexpected EOF")
case isNL(r): case isNL(r):
return lx.errorf("Strings cannot contain new lines.") return lx.errorf("strings cannot contain newlines")
case r == rawStringEnd: case r == rawStringEnd:
lx.backup() lx.backup()
lx.emit(itemRawString) lx.emit(itemRawString)
@ -529,12 +642,13 @@ func lexRawString(lx *lexer) stateFn {
} }
// lexMultilineRawString consumes a raw string. Nothing can be escaped in such // lexMultilineRawString consumes a raw string. Nothing can be escaped in such
// a string. It assumes that the beginning "'" has already been consumed and // a string. It assumes that the beginning "'''" has already been consumed and
// ignored. // ignored.
func lexMultilineRawString(lx *lexer) stateFn { func lexMultilineRawString(lx *lexer) stateFn {
r := lx.next() switch lx.next() {
switch { case eof:
case r == rawStringEnd: return lx.errorf("unexpected EOF")
case rawStringEnd:
if lx.accept(rawStringEnd) { if lx.accept(rawStringEnd) {
if lx.accept(rawStringEnd) { if lx.accept(rawStringEnd) {
lx.backup() lx.backup()
@ -559,12 +673,11 @@ func lexMultilineStringEscape(lx *lexer) stateFn {
// Handle the special case first: // Handle the special case first:
if isNL(lx.next()) { if isNL(lx.next()) {
return lexMultilineString return lexMultilineString
} else { }
lx.backup() lx.backup()
lx.push(lexMultilineString) lx.push(lexMultilineString)
return lexStringEscape(lx) return lexStringEscape(lx)
} }
}
func lexStringEscape(lx *lexer) stateFn { func lexStringEscape(lx *lexer) stateFn {
r := lx.next() r := lx.next()
@ -588,10 +701,9 @@ func lexStringEscape(lx *lexer) stateFn {
case 'U': case 'U':
return lexLongUnicodeEscape return lexLongUnicodeEscape
} }
return lx.errorf("Invalid escape character %q. Only the following "+ return lx.errorf("invalid escape character %q; only the following "+
"escape characters are allowed: "+ "escape characters are allowed: "+
"\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+ `\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX`, r)
"\\uXXXX and \\UXXXXXXXX.", r)
} }
func lexShortUnicodeEscape(lx *lexer) stateFn { func lexShortUnicodeEscape(lx *lexer) stateFn {
@ -599,8 +711,8 @@ func lexShortUnicodeEscape(lx *lexer) stateFn {
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
r = lx.next() r = lx.next()
if !isHexadecimal(r) { if !isHexadecimal(r) {
return lx.errorf("Expected four hexadecimal digits after '\\u', "+ return lx.errorf(`expected four hexadecimal digits after '\u', `+
"but got '%s' instead.", lx.current()) "but got %q instead", lx.current())
} }
} }
return lx.pop() return lx.pop()
@ -611,40 +723,43 @@ func lexLongUnicodeEscape(lx *lexer) stateFn {
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
r = lx.next() r = lx.next()
if !isHexadecimal(r) { if !isHexadecimal(r) {
return lx.errorf("Expected eight hexadecimal digits after '\\U', "+ return lx.errorf(`expected eight hexadecimal digits after '\U', `+
"but got '%s' instead.", lx.current()) "but got %q instead", lx.current())
} }
} }
return lx.pop() return lx.pop()
} }
// lexNumberOrDateStart consumes either a (positive) integer, float or // lexNumberOrDateStart consumes either an integer, a float, or datetime.
// datetime. It assumes that NO negative sign has been consumed.
func lexNumberOrDateStart(lx *lexer) stateFn { func lexNumberOrDateStart(lx *lexer) stateFn {
r := lx.next() r := lx.next()
if !isDigit(r) { if isDigit(r) {
if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.")
} else {
return lx.errorf("Expected a digit but got %q.", r)
}
}
return lexNumberOrDate return lexNumberOrDate
} }
switch r {
case '_':
return lexNumber
case 'e', 'E':
return lexFloat
case '.':
return lx.errorf("floats must start with a digit, not '.'")
}
return lx.errorf("expected a digit but got %q", r)
}
// lexNumberOrDate consumes either a (positive) integer, float or datetime. // lexNumberOrDate consumes either an integer, float or datetime.
func lexNumberOrDate(lx *lexer) stateFn { func lexNumberOrDate(lx *lexer) stateFn {
r := lx.next() r := lx.next()
switch { if isDigit(r) {
case r == '-':
if lx.pos-lx.start != 5 {
return lx.errorf("All ISO8601 dates must be in full Zulu form.")
}
return lexDateAfterYear
case isDigit(r):
return lexNumberOrDate return lexNumberOrDate
case r == '.': }
return lexFloatStart switch r {
case '-':
return lexDatetime
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
} }
lx.backup() lx.backup()
@ -652,46 +767,34 @@ func lexNumberOrDate(lx *lexer) stateFn {
return lx.pop() return lx.pop()
} }
// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format. // lexDatetime consumes a Datetime, to a first approximation.
// It assumes that "YYYY-" has already been consumed. // The parser validates that it matches one of the accepted formats.
func lexDateAfterYear(lx *lexer) stateFn { func lexDatetime(lx *lexer) stateFn {
formats := []rune{
// digits are '0'.
// everything else is direct equality.
'0', '0', '-', '0', '0',
'T',
'0', '0', ':', '0', '0', ':', '0', '0',
'Z',
}
for _, f := range formats {
r := lx.next() r := lx.next()
if f == '0' { if isDigit(r) {
if !isDigit(r) { return lexDatetime
return lx.errorf("Expected digit in ISO8601 datetime, "+
"but found %q instead.", r)
}
} else if f != r {
return lx.errorf("Expected %q in ISO8601 datetime, "+
"but found %q instead.", f, r)
} }
switch r {
case '-', 'T', ':', '.', 'Z', '+':
return lexDatetime
} }
lx.backup()
lx.emit(itemDatetime) lx.emit(itemDatetime)
return lx.pop() return lx.pop()
} }
// lexNumberStart consumes either an integer or a float. It assumes that // lexNumberStart consumes either an integer or a float. It assumes that a sign
// a negative sign has already been read, but that *no* digits have been // has already been read, but that *no* digits have been consumed.
// consumed. lexNumberStart will move to the appropriate integer or float // lexNumberStart will move to the appropriate integer or float states.
// states.
func lexNumberStart(lx *lexer) stateFn { func lexNumberStart(lx *lexer) stateFn {
// we MUST see a digit. Even floats have to start with a digit. // We MUST see a digit. Even floats have to start with a digit.
r := lx.next() r := lx.next()
if !isDigit(r) { if !isDigit(r) {
if r == '.' { if r == '.' {
return lx.errorf("Floats must start with a digit, not '.'.") return lx.errorf("floats must start with a digit, not '.'")
} else {
return lx.errorf("Expected a digit but got %q.", r)
} }
return lx.errorf("expected a digit but got %q", r)
} }
return lexNumber return lexNumber
} }
@ -699,11 +802,14 @@ func lexNumberStart(lx *lexer) stateFn {
// lexNumber consumes an integer or a float after seeing the first digit. // lexNumber consumes an integer or a float after seeing the first digit.
func lexNumber(lx *lexer) stateFn { func lexNumber(lx *lexer) stateFn {
r := lx.next() r := lx.next()
switch { if isDigit(r) {
case isDigit(r):
return lexNumber return lexNumber
case r == '.': }
return lexFloatStart switch r {
case '_':
return lexNumber
case '.', 'e', 'E':
return lexFloat
} }
lx.backup() lx.backup()
@ -711,60 +817,42 @@ func lexNumber(lx *lexer) stateFn {
return lx.pop() return lx.pop()
} }
// lexFloatStart starts the consumption of digits of a float after a '.'. // lexFloat consumes the elements of a float. It allows any sequence of
// Namely, at least one digit is required. // float-like characters, so floats emitted by the lexer are only a first
func lexFloatStart(lx *lexer) stateFn { // approximation and must be validated by the parser.
r := lx.next()
if !isDigit(r) {
return lx.errorf("Floats must have a digit after the '.', but got "+
"%q instead.", r)
}
return lexFloat
}
// lexFloat consumes the digits of a float after a '.'.
// Assumes that one digit has been consumed after a '.' already.
func lexFloat(lx *lexer) stateFn { func lexFloat(lx *lexer) stateFn {
r := lx.next() r := lx.next()
if isDigit(r) { if isDigit(r) {
return lexFloat return lexFloat
} }
switch r {
case '_', '.', '-', '+', 'e', 'E':
return lexFloat
}
lx.backup() lx.backup()
lx.emit(itemFloat) lx.emit(itemFloat)
return lx.pop() return lx.pop()
} }
// lexConst consumes the s[1:] in s. It assumes that s[0] has already been // lexBool consumes a bool string: 'true' or 'false.
// consumed. func lexBool(lx *lexer) stateFn {
func lexConst(lx *lexer, s string) stateFn { var rs []rune
for i := range s[1:] { for {
if r := lx.next(); r != rune(s[i+1]) { r := lx.next()
return lx.errorf("Expected %q, but found %q instead.", s[:i+1], if !unicode.IsLetter(r) {
s[:i]+string(r)) lx.backup()
break
} }
rs = append(rs, r)
} }
return nil s := string(rs)
} switch s {
case "true", "false":
// lexTrue consumes the "rue" in "true". It assumes that 't' has already
// been consumed.
func lexTrue(lx *lexer) stateFn {
if fn := lexConst(lx, "true"); fn != nil {
return fn
}
lx.emit(itemBool) lx.emit(itemBool)
return lx.pop() return lx.pop()
} }
return lx.errorf("expected value but found %q instead", s)
// lexFalse consumes the "alse" in "false". It assumes that 'f' has already
// been consumed.
func lexFalse(lx *lexer) stateFn {
if fn := lexConst(lx, "false"); fn != nil {
return fn
}
lx.emit(itemBool)
return lx.pop()
} }
// lexCommentStart begins the lexing of a comment. It will emit // lexCommentStart begins the lexing of a comment. It will emit
@ -834,13 +922,7 @@ func (itype itemType) String() string {
return "EOF" return "EOF"
case itemText: case itemText:
return "Text" return "Text"
case itemString: case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
return "String"
case itemRawString:
return "String"
case itemMultilineString:
return "String"
case itemRawMultilineString:
return "String" return "String"
case itemBool: case itemBool:
return "Bool" return "Bool"

View File

@ -2,7 +2,6 @@ package toml
import ( import (
"fmt" "fmt"
"log"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -81,7 +80,7 @@ func (p *parser) next() item {
} }
func (p *parser) bug(format string, v ...interface{}) { func (p *parser) bug(format string, v ...interface{}) {
log.Panicf("BUG: %s\n\n", fmt.Sprintf(format, v...)) panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
} }
func (p *parser) expect(typ itemType) item { func (p *parser) expect(typ itemType) item {
@ -179,10 +178,18 @@ func (p *parser) value(it item) (interface{}, tomlType) {
} }
p.bug("Expected boolean value, but got '%s'.", it.val) p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger: case itemInteger:
num, err := strconv.ParseInt(it.val, 10, 64) if !numUnderscoresOK(it.val) {
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseInt(val, 10, 64)
if err != nil { if err != nil {
// See comment below for floats describing why we make a // Distinguish integer values. Normally, it'd be a bug if the lexer
// distinction between a bug and a user error. // provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok && if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange { e.Err == strconv.ErrRange {
@ -194,29 +201,57 @@ func (p *parser) value(it item) (interface{}, tomlType) {
} }
return num, p.typeOfPrimitive(it) return num, p.typeOfPrimitive(it)
case itemFloat: case itemFloat:
num, err := strconv.ParseFloat(it.val, 64) parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicf("Invalid float %q: underscores must be "+
"surrounded by digits", it.val)
}
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicf("Invalid float %q: '.' must be followed "+
"by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseFloat(val, 64)
if err != nil { if err != nil {
// Distinguish float values. Normally, it'd be a bug if the lexer
// provides an invalid float, but it's possible that the float is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
//
// This is also true for integers.
if e, ok := err.(*strconv.NumError); ok && if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange { e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+ p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val) "IEEE-754 floating-point numbers.", it.val)
} else { } else {
p.bug("Expected float value, but got '%s'.", it.val) p.panicf("Invalid float value: %q", it.val)
} }
} }
return num, p.typeOfPrimitive(it) return num, p.typeOfPrimitive(it)
case itemDatetime: case itemDatetime:
t, err := time.Parse("2006-01-02T15:04:05Z", it.val) var t time.Time
if err != nil { var ok bool
p.panicf("Invalid RFC3339 Zulu DateTime: '%s'.", it.val) var err error
for _, format := range []string{
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05",
"2006-01-02",
} {
t, err = time.ParseInLocation(format, it.val, time.Local)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicf("Invalid TOML Datetime: %q.", it.val)
} }
return t, p.typeOfPrimitive(it) return t, p.typeOfPrimitive(it)
case itemArray: case itemArray:
@ -234,11 +269,75 @@ func (p *parser) value(it item) (interface{}, tomlType) {
types = append(types, typ) types = append(types, typ)
} }
return array, p.typeOfArray(types) return array, p.typeOfArray(types)
case itemInlineTableStart:
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
p.currentKey = ""
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ != itemKeyStart {
p.bug("Expected key start but instead found %q, around line %d",
it.val, p.approxLine)
}
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
// retrieve key
k := p.next()
p.approxLine = k.line
kname := p.keyString(k)
// retrieve value
p.currentKey = kname
val, typ := p.value(p.next())
// make sure we keep metadata up to date
p.setType(kname, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[kname] = val
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
} }
p.bug("Unexpected value type: %s", it.typ) p.bug("Unexpected value type: %s", it.typ)
panic("unreachable") panic("unreachable")
} }
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores.
func numUnderscoresOK(s string) bool {
accept := false
for _, r := range s {
if r == '_' {
if !accept {
return false
}
accept = false
continue
}
accept = true
}
return accept
}
// numPeriodsOK checks whether every period in s is followed by a digit.
func numPeriodsOK(s string) bool {
period := false
for _, r := range s {
if period && !isDigit(r) {
return false
}
period = r == '.'
}
return !period
}
// establishContext sets the current context of the parser, // establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is // where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter. // set depends on the value of the `array` parameter.

View File

@ -95,8 +95,8 @@ func typeFields(t reflect.Type) []field {
if sf.PkgPath != "" && !sf.Anonymous { // unexported if sf.PkgPath != "" && !sf.Anonymous { // unexported
continue continue
} }
name, _ := getOptions(sf.Tag.Get("toml")) opts := getOptions(sf.Tag)
if name == "-" { if opts.skip {
continue continue
} }
index := make([]int, len(f.index)+1) index := make([]int, len(f.index)+1)
@ -110,8 +110,9 @@ func typeFields(t reflect.Type) []field {
} }
// Record found field and index sequence. // Record found field and index sequence.
if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
tagged := name != "" tagged := opts.name != ""
name := opts.name
if name == "" { if name == "" {
name = sf.Name name = sf.Name
} }