From c7a8d29f12c57ad5b214b39ad29f5d2a36839a3a Mon Sep 17 00:00:00 2001 From: Black-Hole1 Date: Wed, 31 May 2023 16:27:19 +0800 Subject: [PATCH] refactor: improve get ssh path duplicate code Signed-off-by: Black-Hole1 --- pkg/machine/applehv/machine.go | 7 +++---- pkg/machine/hyperv/machine.go | 9 +++------ pkg/machine/keys.go | 19 ++++++++++++------- pkg/machine/qemu/machine.go | 8 +++----- pkg/machine/wsl/machine.go | 26 ++++++++++---------------- pkg/util/utils.go | 14 +++----------- pkg/util/utils_test.go | 8 ++++++++ 7 files changed, 42 insertions(+), 49 deletions(-) diff --git a/pkg/machine/applehv/machine.go b/pkg/machine/applehv/machine.go index 06abea8672..58ea2676ea 100644 --- a/pkg/machine/applehv/machine.go +++ b/pkg/machine/applehv/machine.go @@ -17,7 +17,7 @@ import ( "github.com/containers/common/pkg/config" "github.com/containers/podman/v4/pkg/machine" - "github.com/containers/storage/pkg/homedir" + "github.com/containers/podman/v4/pkg/util" "github.com/docker/go-units" "github.com/sirupsen/logrus" ) @@ -123,8 +123,7 @@ func (m *MacMachine) Init(opts machine.InitOptions) (bool, error) { } m.VfkitHelper = *vfhelper - sshDir := filepath.Join(homedir.Get(), ".ssh") - m.IdentityPath = filepath.Join(sshDir, m.Name) + m.IdentityPath = util.GetIdentityPath(m.Name) m.Rootful = opts.Rootful m.RemoteUsername = opts.Username @@ -142,7 +141,7 @@ func (m *MacMachine) Init(opts machine.InitOptions) (bool, error) { // TODO localhost needs to be restored here uri := machine.SSHRemoteConnection.MakeSSHURL("192.168.64.2", fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername) uriRoot := machine.SSHRemoteConnection.MakeSSHURL("localhost", "/run/podman/podman.sock", strconv.Itoa(m.Port), "root") - identity := filepath.Join(sshDir, m.Name) + identity := m.IdentityPath uris := []url.URL{uri, uriRoot} names := []string{m.Name, m.Name + "-root"} diff --git a/pkg/machine/hyperv/machine.go b/pkg/machine/hyperv/machine.go index baba8d68db..d0473663e4 100644 --- a/pkg/machine/hyperv/machine.go +++ b/pkg/machine/hyperv/machine.go @@ -18,8 +18,8 @@ import ( "github.com/containers/common/pkg/config" "github.com/containers/libhvee/pkg/hypervctl" "github.com/containers/podman/v4/pkg/machine" + "github.com/containers/podman/v4/pkg/util" "github.com/containers/podman/v4/utils" - "github.com/containers/storage/pkg/homedir" "github.com/containers/storage/pkg/ioutils" "github.com/docker/go-units" "github.com/sirupsen/logrus" @@ -91,9 +91,7 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) { } m.NetworkHVSock = *networkHVSock m.ReadyHVSock = *eventHVSocket - - sshDir := filepath.Join(homedir.Get(), ".ssh") - m.IdentityPath = filepath.Join(sshDir, m.Name) + m.IdentityPath = util.GetIdentityPath(m.Name) // TODO This needs to be fixed in c-common m.RemoteUsername = "core" @@ -111,7 +109,6 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) { if len(opts.IgnitionPath) < 1 { uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername) uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(m.Port), "root") - identity := filepath.Join(sshDir, m.Name) uris := []url.URL{uri, uriRoot} names := []string{m.Name, m.Name + "-root"} @@ -123,7 +120,7 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) { } for i := 0; i < 2; i++ { - if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil { + if err := machine.AddConnection(&uris[i], names[i], m.IdentityPath, opts.IsDefault && i == 0); err != nil { return false, err } } diff --git a/pkg/machine/keys.go b/pkg/machine/keys.go index fce405695a..16561df30b 100644 --- a/pkg/machine/keys.go +++ b/pkg/machine/keys.go @@ -33,18 +33,16 @@ func CreateSSHKeys(writeLocation string) (string, error) { return strings.TrimSuffix(string(b), "\n"), nil } -func CreateSSHKeysPrefix(dir string, file string, passThru bool, skipExisting bool, prefix ...string) (string, error) { - location := filepath.Join(dir, file) - - _, e := os.Stat(location) +func CreateSSHKeysPrefix(identityPath string, passThru bool, skipExisting bool, prefix ...string) (string, error) { + _, e := os.Stat(identityPath) if !skipExisting || errors.Is(e, os.ErrNotExist) { - if err := generatekeysPrefix(dir, file, passThru, prefix...); err != nil { + if err := generatekeysPrefix(identityPath, passThru, prefix...); err != nil { return "", err } } else { fmt.Println("Keys already exist, reusing") } - b, err := os.ReadFile(filepath.Join(dir, file) + ".pub") + b, err := os.ReadFile(identityPath + ".pub") if err != nil { return "", err } @@ -74,7 +72,14 @@ func generatekeys(writeLocation string) error { } // generatekeys creates an ed25519 set of keys -func generatekeysPrefix(dir string, file string, passThru bool, prefix ...string) error { +func generatekeysPrefix(identityPath string, passThru bool, prefix ...string) error { + dir := filepath.Dir(identityPath) + file := filepath.Base(identityPath) + + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("could not create ssh directory: %w", err) + } + args := append([]string{}, prefix[1:]...) args = append(args, sshCommand...) args = append(args, file) diff --git a/pkg/machine/qemu/machine.go b/pkg/machine/qemu/machine.go index 6cfc394620..390d23f078 100644 --- a/pkg/machine/qemu/machine.go +++ b/pkg/machine/qemu/machine.go @@ -27,8 +27,8 @@ import ( "github.com/containers/common/pkg/config" "github.com/containers/podman/v4/pkg/machine" "github.com/containers/podman/v4/pkg/rootless" + "github.com/containers/podman/v4/pkg/util" "github.com/containers/podman/v4/utils" - "github.com/containers/storage/pkg/homedir" "github.com/containers/storage/pkg/ioutils" "github.com/digitalocean/go-qemu/qmp" "github.com/docker/go-units" @@ -242,8 +242,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) { var ( key string ) - sshDir := filepath.Join(homedir.Get(), ".ssh") - v.IdentityPath = filepath.Join(sshDir, v.Name) + v.IdentityPath = util.GetIdentityPath(v.Name) v.Rootful = opts.Rootful switch opts.ImagePath { @@ -320,7 +319,6 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) { if len(opts.IgnitionPath) < 1 { uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", v.UID), strconv.Itoa(v.Port), v.RemoteUsername) uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(v.Port), "root") - identity := filepath.Join(sshDir, v.Name) uris := []url.URL{uri, uriRoot} names := []string{v.Name, v.Name + "-root"} @@ -332,7 +330,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) { } for i := 0; i < 2; i++ { - if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil { + if err := machine.AddConnection(&uris[i], names[i], v.IdentityPath, opts.IsDefault && i == 0); err != nil { return false, err } } diff --git a/pkg/machine/wsl/machine.go b/pkg/machine/wsl/machine.go index b4674e6d46..16dfd5abde 100644 --- a/pkg/machine/wsl/machine.go +++ b/pkg/machine/wsl/machine.go @@ -21,6 +21,7 @@ import ( "github.com/containers/common/pkg/config" "github.com/containers/podman/v4/pkg/machine" "github.com/containers/podman/v4/pkg/machine/wsl/wutil" + "github.com/containers/podman/v4/pkg/util" "github.com/containers/podman/v4/utils" "github.com/containers/storage/pkg/homedir" "github.com/containers/storage/pkg/ioutils" @@ -406,9 +407,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) { } _ = setupWslProxyEnv() - homeDir := homedir.Get() - sshDir := filepath.Join(homeDir, ".ssh") - v.IdentityPath = filepath.Join(sshDir, v.Name) + v.IdentityPath = util.GetIdentityPath(v.Name) v.Rootful = opts.Rootful v.Version = currentMachineVersion @@ -438,7 +437,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) { return false, err } - if err = createKeys(v, dist, sshDir); err != nil { + if err = createKeys(v, dist); err != nil { return false, err } @@ -449,7 +448,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) { return false, err } - if err := setupConnections(v, opts, sshDir); err != nil { + if err := setupConnections(v, opts); err != nil { return false, err } @@ -502,10 +501,9 @@ func (v *MachineVM) writeConfig() error { return nil } -func setupConnections(v *MachineVM, opts machine.InitOptions, sshDir string) error { +func setupConnections(v *MachineVM, opts machine.InitOptions) error { uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootlessSock, strconv.Itoa(v.Port), v.RemoteUsername) uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootfulSock, strconv.Itoa(v.Port), "root") - identity := filepath.Join(sshDir, v.Name) uris := []url.URL{uri, uriRoot} names := []string{v.Name, v.Name + "-root"} @@ -517,7 +515,7 @@ func setupConnections(v *MachineVM, opts machine.InitOptions, sshDir string) err } for i := 0; i < 2; i++ { - if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil { + if err := machine.AddConnection(&uris[i], names[i], v.IdentityPath, opts.IsDefault && i == 0); err != nil { return err } } @@ -551,18 +549,14 @@ func provisionWSLDist(name string, imagePath string, prompt string) (string, err return dist, nil } -func createKeys(v *MachineVM, dist string, sshDir string) error { +func createKeys(v *MachineVM, dist string) error { user := v.RemoteUsername - if err := os.MkdirAll(sshDir, 0700); err != nil { - return fmt.Errorf("could not create ssh directory: %w", err) - } - if err := terminateDist(dist); err != nil { return fmt.Errorf("could not cycle WSL dist: %w", err) } - key, err := wslCreateKeys(sshDir, v.Name, dist) + key, err := wslCreateKeys(v.IdentityPath, dist) if err != nil { return fmt.Errorf("could not create ssh keys: %w", err) } @@ -972,8 +966,8 @@ func wslPipe(input string, dist string, arg ...string) error { return pipeCmdPassThrough("wsl", input, newArgs...) } -func wslCreateKeys(sshDir string, name string, dist string) (string, error) { - return machine.CreateSSHKeysPrefix(sshDir, name, true, true, "wsl", "-u", "root", "-d", dist) +func wslCreateKeys(identityPath string, dist string) (string, error) { + return machine.CreateSSHKeysPrefix(identityPath, true, true, "wsl", "-u", "root", "-d", dist) } func runCmdPassThrough(name string, arg ...string) error { diff --git a/pkg/util/utils.go b/pkg/util/utils.go index 2e52180a07..3f47c6b4cf 100644 --- a/pkg/util/utils.go +++ b/pkg/util/utils.go @@ -26,6 +26,7 @@ import ( "github.com/containers/podman/v4/pkg/rootless" "github.com/containers/podman/v4/pkg/signal" "github.com/containers/storage/pkg/directory" + "github.com/containers/storage/pkg/homedir" "github.com/containers/storage/pkg/idtools" stypes "github.com/containers/storage/types" securejoin "github.com/cyphar/filepath-securejoin" @@ -473,17 +474,8 @@ func ExitCode(err error) int { return 126 } -// HomeDir returns the home directory for the current user. -func HomeDir() (string, error) { - home := os.Getenv("HOME") - if home == "" { - usr, err := user.LookupId(fmt.Sprintf("%d", rootless.GetRootlessUID())) - if err != nil { - return "", fmt.Errorf("unable to resolve HOME directory: %w", err) - } - home = usr.HomeDir - } - return home, nil +func GetIdentityPath(name string) string { + return filepath.Join(homedir.Get(), ".ssh", name) } func Tmpdir() string { diff --git a/pkg/util/utils_test.go b/pkg/util/utils_test.go index df1722a003..0ef48a6f3a 100644 --- a/pkg/util/utils_test.go +++ b/pkg/util/utils_test.go @@ -2,9 +2,11 @@ package util import ( "fmt" + "path/filepath" "testing" "time" + "github.com/containers/storage/pkg/homedir" "github.com/opencontainers/runtime-spec/specs-go" "github.com/stretchr/testify/assert" ) @@ -56,6 +58,12 @@ func TestValidateSysctlBadSysctlWithExtraSpaces(t *testing.T) { assert.Equal(t, err.Error(), fmt.Sprintf(expectedError, strSlice2[1])) } +func TestGetIdentityPath(t *testing.T) { + name := "p-test" + identityPath := GetIdentityPath(name) + assert.Equal(t, identityPath, filepath.Join(homedir.Get(), ".ssh", name)) +} + func TestCoresToPeriodAndQuota(t *testing.T) { cores := 1.0 expectedPeriod := DefaultCPUPeriod