diff --git a/pkg/bindings/containers/attach.go b/pkg/bindings/containers/attach.go index 6b2aab45e9..ff42c9928a 100644 --- a/pkg/bindings/containers/attach.go +++ b/pkg/bindings/containers/attach.go @@ -111,37 +111,11 @@ func Attach(ctx context.Context, nameOrID string, stdin io.Reader, stdout io.Wri }() } - headers := make(http.Header) - headers.Add("Connection", "Upgrade") - headers.Add("Upgrade", "tcp") - - var socket net.Conn - socketSet := false - dialContext := conn.Client.Transport.(*http.Transport).DialContext - t := &http.Transport{ - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - c, err := dialContext(ctx, network, address) - if err != nil { - return nil, err - } - if !socketSet { - socket = c - socketSet = true - } - return c, err - }, - IdleConnTimeout: time.Duration(0), - } - conn.Client.Transport = t - response, err := conn.DoRequest(ctx, nil, http.MethodPost, "/containers/%s/attach", params, headers, nameOrID) + cw, socket, err := newUpgradeRequest(ctx, conn, nil, fmt.Sprintf("/containers/%s/attach", nameOrID), params) if err != nil { return err } - - if !response.IsSuccess() && !response.IsInformational() { - defer response.Body.Close() - return response.Process(nil) - } + defer socket.Close() if needTTY { winChange := make(chan os.Signal, 1) @@ -173,11 +147,7 @@ func Attach(ctx context.Context, nameOrID string, stdin io.Reader, stdout io.Wri logrus.Errorf("Failed to write input to service: %v", err) } if err == nil { - if closeWrite, ok := socket.(CloseWriter); ok { - if err := closeWrite.CloseWrite(); err != nil { - logrus.Warnf("Failed to close STDIN for writing: %v", err) - } - } + cw.CloseWrite() } stdinChan <- err }() @@ -210,7 +180,7 @@ func Attach(ctx context.Context, nameOrID string, stdin io.Reader, stdout io.Wri return err } - return nil + return <-stdoutChan } } } else { @@ -464,33 +434,11 @@ func ExecStartAndAttach(ctx context.Context, sessionID string, options *ExecStar return err } - var socket net.Conn - socketSet := false - dialContext := conn.Client.Transport.(*http.Transport).DialContext - t := &http.Transport{ - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - c, err := dialContext(ctx, network, address) - if err != nil { - return nil, err - } - if !socketSet { - socket = c - socketSet = true - } - return c, err - }, - IdleConnTimeout: time.Duration(0), - } - conn.Client.Transport = t - response, err := conn.DoRequest(ctx, bytes.NewReader(bodyJSON), http.MethodPost, "/exec/%s/start", nil, nil, sessionID) + cw, socket, err := newUpgradeRequest(ctx, conn, bytes.NewReader(bodyJSON), fmt.Sprintf("/exec/%s/start", sessionID), nil) if err != nil { return err } - defer response.Body.Close() - - if !response.IsSuccess() && !response.IsInformational() { - return response.Process(nil) - } + defer socket.Close() if needTTY { winChange := make(chan os.Signal, 1) @@ -513,12 +461,7 @@ func ExecStartAndAttach(ctx context.Context, sessionID string, options *ExecStar } if err == nil { - if closeWrite, ok := socket.(CloseWriter); ok { - logrus.Debugf("Closing STDIN") - if err := closeWrite.CloseWrite(); err != nil { - logrus.Warnf("Failed to close STDIN for writing: %v", err) - } - } + cw.CloseWrite() } }() } @@ -580,3 +523,66 @@ func ExecStartAndAttach(ctx context.Context, sessionID string, options *ExecStar } return nil } + +type closeWrite struct { + // sock is the underlying socket of the connection. + // Do not use that field directly. + sock net.Conn +} + +func (cw *closeWrite) CloseWrite() { + if closeWrite, ok := cw.sock.(CloseWriter); ok { + logrus.Debugf("Closing STDIN") + if err := closeWrite.CloseWrite(); err != nil { + logrus.Warnf("Failed to close STDIN for writing: %v", err) + } + } +} + +// newUpgradeRequest performs a new http Upgrade request, it return the closeWrite which should be used +// to close the STDIN side used and the ReadWriter which MUST be uses to write/read from the connection +// and which must closed when finished. Do not access the new.Conn in closeWrite directly. +func newUpgradeRequest(ctx context.Context, conn *bindings.Connection, body io.Reader, path string, params url.Values) (*closeWrite, io.ReadWriteCloser, error) { + headers := http.Header{ + "Connection": []string{"Upgrade"}, + "Upgrade": []string{"tcp"}, + } + + var socket net.Conn + socketSet := false + dialContext := conn.Client.Transport.(*http.Transport).DialContext + t := &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + c, err := dialContext(ctx, network, address) + if err != nil { + return nil, err + } + if !socketSet { + socket = c + socketSet = true + } + return c, err + }, + IdleConnTimeout: time.Duration(0), + } + conn.Client.Transport = t + response, err := conn.DoRequest(ctx, body, http.MethodPost, path, params, headers) + if err != nil { + return nil, nil, err + } + + if response.StatusCode != http.StatusSwitchingProtocols { + defer response.Body.Close() + if err := response.Process(nil); err != nil { + return nil, nil, err + } + return nil, nil, fmt.Errorf("incorrect server response code %d, expected %d", response.StatusCode, http.StatusSwitchingProtocols) + } + rw, ok := response.Body.(io.ReadWriteCloser) + if !ok { + response.Body.Close() + return nil, nil, errors.New("internal error: cannot cast to http response Body to io.ReadWriteCloser") + } + + return &closeWrite{sock: socket}, rw, nil +}