From 45550858449935c831c037e5061912c7fb28175f Mon Sep 17 00:00:00 2001 From: Artem Andreenko Date: Tue, 13 Oct 2015 01:09:55 +0300 Subject: [PATCH] fix races in http cors License: MIT Signed-off-by: Artem Andreenko --- commands/http/handler.go | 62 +++++++++++++++++++++++++++++++---- commands/http/handler_test.go | 12 +++---- core/corehttp/commands.go | 35 +++++++++----------- 3 files changed, 76 insertions(+), 33 deletions(-) diff --git a/commands/http/handler.go b/commands/http/handler.go index 86be6c2a4..4ce1d39a6 100644 --- a/commands/http/handler.go +++ b/commands/http/handler.go @@ -11,6 +11,7 @@ import ( "runtime" "strconv" "strings" + "sync" cors "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/rs/cors" context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context" @@ -68,8 +69,11 @@ type ServerConfig struct { // Headers is an optional map of headers that is written out. Headers map[string][]string - // CORSOpts is a set of options for CORS headers. - CORSOpts *cors.Options + // cORSOpts is a set of options for CORS headers. + cORSOpts *cors.Options + + // cORSOptsRWMutex is a RWMutex for read/write CORSOpts + cORSOptsRWMutex sync.RWMutex } func skipAPIHeader(h string) bool { @@ -93,7 +97,7 @@ func NewHandler(ctx cmds.Context, root *cmds.Command, cfg *ServerConfig) *Handle // Wrap the internal handler with CORS handling-middleware. // Create a handler for the API. internal := internalHandler{ctx, root, cfg} - c := cors.New(*cfg.CORSOpts) + c := cors.New(*cfg.cORSOpts) return &Handler{internal, c.Handler(internal)} } @@ -322,6 +326,51 @@ func sanitizedErrStr(err error) string { return s } +func NewServerConfig() *ServerConfig { + cfg := new(ServerConfig) + cfg.cORSOpts = new(cors.Options) + return cfg +} + +func (cfg ServerConfig) AllowedOrigins() []string { + cfg.cORSOptsRWMutex.RLock() + defer cfg.cORSOptsRWMutex.RUnlock() + return cfg.cORSOpts.AllowedOrigins +} + +func (cfg *ServerConfig) SetAllowedOrigins(origins ...string) { + cfg.cORSOptsRWMutex.Lock() + defer cfg.cORSOptsRWMutex.Unlock() + cfg.cORSOpts.AllowedOrigins = origins +} + +func (cfg *ServerConfig) AppendAllowedOrigins(origins ...string) { + cfg.cORSOptsRWMutex.Lock() + defer cfg.cORSOptsRWMutex.Unlock() + cfg.cORSOpts.AllowedOrigins = append(cfg.cORSOpts.AllowedOrigins, origins...) +} + +func (cfg ServerConfig) AllowedMethods() []string { + cfg.cORSOptsRWMutex.RLock() + defer cfg.cORSOptsRWMutex.RUnlock() + return []string(cfg.cORSOpts.AllowedMethods) +} + +func (cfg *ServerConfig) SetAllowedMethods(methods ...string) { + cfg.cORSOptsRWMutex.Lock() + defer cfg.cORSOptsRWMutex.Unlock() + if cfg.cORSOpts == nil { + cfg.cORSOpts = new(cors.Options) + } + cfg.cORSOpts.AllowedMethods = methods +} + +func (cfg *ServerConfig) SetAllowCredentials(flag bool) { + cfg.cORSOptsRWMutex.Lock() + defer cfg.cORSOptsRWMutex.Unlock() + cfg.cORSOpts.AllowCredentials = flag +} + // allowOrigin just stops the request if the origin is not allowed. // the CORS middleware apparently does not do this for us... func allowOrigin(r *http.Request, cfg *ServerConfig) bool { @@ -333,8 +382,8 @@ func allowOrigin(r *http.Request, cfg *ServerConfig) bool { if origin == "" { return true } - - for _, o := range cfg.CORSOpts.AllowedOrigins { + origins := cfg.AllowedOrigins() + for _, o := range origins { if o == "*" { // ok! you asked for it! return true } @@ -375,7 +424,8 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool { // check CORS ACAOs and pretend Referer works like an origin. // this is valid for many (most?) sane uses of the API in // other applications, and will have the desired effect. - for _, o := range cfg.CORSOpts.AllowedOrigins { + origins := cfg.AllowedOrigins() + for _, o := range origins { if o == "*" { // ok! you asked for it! return true } diff --git a/commands/http/handler_test.go b/commands/http/handler_test.go index b61a41457..86f1e8118 100644 --- a/commands/http/handler_test.go +++ b/commands/http/handler_test.go @@ -6,8 +6,6 @@ import ( "net/url" "testing" - cors "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/rs/cors" - cmds "github.com/ipfs/go-ipfs/commands" ipfscmd "github.com/ipfs/go-ipfs/core/commands" coremock "github.com/ipfs/go-ipfs/core/mock" @@ -28,12 +26,10 @@ func assertStatus(t *testing.T, actual, expected int) { } func originCfg(origins []string) *ServerConfig { - return &ServerConfig{ - CORSOpts: &cors.Options{ - AllowedOrigins: origins, - AllowedMethods: []string{"GET", "PUT", "POST"}, - }, - } + cfg := NewServerConfig() + cfg.SetAllowedOrigins(origins...) + cfg.SetAllowedMethods("GET", "PUT", "POST") + return cfg } type testCase struct { diff --git a/core/corehttp/commands.go b/core/corehttp/commands.go index 99f544905..882121c4e 100644 --- a/core/corehttp/commands.go +++ b/core/corehttp/commands.go @@ -7,8 +7,6 @@ import ( "strconv" "strings" - cors "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/rs/cors" - commands "github.com/ipfs/go-ipfs/commands" cmdsHttp "github.com/ipfs/go-ipfs/commands/http" core "github.com/ipfs/go-ipfs/core" @@ -41,10 +39,10 @@ func addCORSFromEnv(c *cmdsHttp.ServerConfig) { origin := os.Getenv(originEnvKey) if origin != "" { log.Warning(originEnvKeyDeprecate) - if c.CORSOpts == nil { - c.CORSOpts.AllowedOrigins = []string{origin} + if len(c.AllowedOrigins()) == 0 { + c.SetAllowedOrigins([]string{origin}...) } - c.CORSOpts.AllowedOrigins = append(c.CORSOpts.AllowedOrigins, origin) + c.AppendAllowedOrigins(origin) } } @@ -52,14 +50,14 @@ func addHeadersFromConfig(c *cmdsHttp.ServerConfig, nc *config.Config) { log.Info("Using API.HTTPHeaders:", nc.API.HTTPHeaders) if acao := nc.API.HTTPHeaders[cmdsHttp.ACAOrigin]; acao != nil { - c.CORSOpts.AllowedOrigins = acao + c.SetAllowedOrigins(acao...) } if acam := nc.API.HTTPHeaders[cmdsHttp.ACAMethods]; acam != nil { - c.CORSOpts.AllowedMethods = acam + c.SetAllowedMethods(acam...) } if acac := nc.API.HTTPHeaders[cmdsHttp.ACACredentials]; acac != nil { for _, v := range acac { - c.CORSOpts.AllowCredentials = (strings.ToLower(v) == "true") + c.SetAllowCredentials(strings.ToLower(v) == "true") } } @@ -68,13 +66,13 @@ func addHeadersFromConfig(c *cmdsHttp.ServerConfig, nc *config.Config) { func addCORSDefaults(c *cmdsHttp.ServerConfig) { // by default use localhost origins - if len(c.CORSOpts.AllowedOrigins) == 0 { - c.CORSOpts.AllowedOrigins = defaultLocalhostOrigins + if len(c.AllowedOrigins()) == 0 { + c.SetAllowedOrigins(defaultLocalhostOrigins...) } // by default, use GET, PUT, POST - if len(c.CORSOpts.AllowedMethods) == 0 { - c.CORSOpts.AllowedMethods = []string{"GET", "POST", "PUT"} + if len(c.AllowedMethods()) == 0 { + c.SetAllowedMethods("GET", "POST", "PUT") } } @@ -90,23 +88,22 @@ func patchCORSVars(c *cmdsHttp.ServerConfig, addr net.Addr) { } // we're listening on tcp/udp with ports. ("udp!?" you say? yeah... it happens...) - for i, o := range c.CORSOpts.AllowedOrigins { + origins := c.AllowedOrigins() + for i, o := range origins { // TODO: allow replacing . tricky, ip4 and ip6 and hostnames... if port != "" { o = strings.Replace(o, "", port, -1) } - c.CORSOpts.AllowedOrigins[i] = o + origins[i] = o } + c.SetAllowedOrigins(origins...) } func commandsOption(cctx commands.Context, command *commands.Command) ServeOption { return func(n *core.IpfsNode, l net.Listener, mux *http.ServeMux) (*http.ServeMux, error) { - cfg := &cmdsHttp.ServerConfig{ - CORSOpts: &cors.Options{ - AllowedMethods: []string{"GET", "POST", "PUT"}, - }, - } + cfg := cmdsHttp.NewServerConfig() + cfg.SetAllowedMethods("GET", "POST", "PUT") rcfg, err := n.Repo.Config() if err != nil { return nil, err