Merge pull request #18750 from BlackHole1/improve-ssh

refactor: improve get identity path duplicate code
This commit is contained in:
OpenShift Merge Robot
2023-06-07 08:13:17 -04:00
committed by GitHub
7 changed files with 42 additions and 49 deletions

View File

@ -17,7 +17,7 @@ import (
"github.com/containers/common/pkg/config" "github.com/containers/common/pkg/config"
"github.com/containers/podman/v4/pkg/machine" "github.com/containers/podman/v4/pkg/machine"
"github.com/containers/storage/pkg/homedir" "github.com/containers/podman/v4/pkg/util"
"github.com/docker/go-units" "github.com/docker/go-units"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -123,8 +123,7 @@ func (m *MacMachine) Init(opts machine.InitOptions) (bool, error) {
} }
m.VfkitHelper = *vfhelper m.VfkitHelper = *vfhelper
sshDir := filepath.Join(homedir.Get(), ".ssh") m.IdentityPath = util.GetIdentityPath(m.Name)
m.IdentityPath = filepath.Join(sshDir, m.Name)
m.Rootful = opts.Rootful m.Rootful = opts.Rootful
m.RemoteUsername = opts.Username m.RemoteUsername = opts.Username
@ -142,7 +141,7 @@ func (m *MacMachine) Init(opts machine.InitOptions) (bool, error) {
// TODO localhost needs to be restored here // TODO localhost needs to be restored here
uri := machine.SSHRemoteConnection.MakeSSHURL("192.168.64.2", fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername) uri := machine.SSHRemoteConnection.MakeSSHURL("192.168.64.2", fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL("localhost", "/run/podman/podman.sock", strconv.Itoa(m.Port), "root") uriRoot := machine.SSHRemoteConnection.MakeSSHURL("localhost", "/run/podman/podman.sock", strconv.Itoa(m.Port), "root")
identity := filepath.Join(sshDir, m.Name) identity := m.IdentityPath
uris := []url.URL{uri, uriRoot} uris := []url.URL{uri, uriRoot}
names := []string{m.Name, m.Name + "-root"} names := []string{m.Name, m.Name + "-root"}

View File

@ -18,8 +18,8 @@ import (
"github.com/containers/common/pkg/config" "github.com/containers/common/pkg/config"
"github.com/containers/libhvee/pkg/hypervctl" "github.com/containers/libhvee/pkg/hypervctl"
"github.com/containers/podman/v4/pkg/machine" "github.com/containers/podman/v4/pkg/machine"
"github.com/containers/podman/v4/pkg/util"
"github.com/containers/podman/v4/utils" "github.com/containers/podman/v4/utils"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/ioutils" "github.com/containers/storage/pkg/ioutils"
"github.com/docker/go-units" "github.com/docker/go-units"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -91,9 +91,7 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) {
} }
m.NetworkHVSock = *networkHVSock m.NetworkHVSock = *networkHVSock
m.ReadyHVSock = *eventHVSocket m.ReadyHVSock = *eventHVSocket
m.IdentityPath = util.GetIdentityPath(m.Name)
sshDir := filepath.Join(homedir.Get(), ".ssh")
m.IdentityPath = filepath.Join(sshDir, m.Name)
// TODO This needs to be fixed in c-common // TODO This needs to be fixed in c-common
m.RemoteUsername = "core" m.RemoteUsername = "core"
@ -111,7 +109,6 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) {
if len(opts.IgnitionPath) < 1 { if len(opts.IgnitionPath) < 1 {
uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername) uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(m.Port), "root") uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(m.Port), "root")
identity := filepath.Join(sshDir, m.Name)
uris := []url.URL{uri, uriRoot} uris := []url.URL{uri, uriRoot}
names := []string{m.Name, m.Name + "-root"} names := []string{m.Name, m.Name + "-root"}
@ -123,7 +120,7 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) {
} }
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil { if err := machine.AddConnection(&uris[i], names[i], m.IdentityPath, opts.IsDefault && i == 0); err != nil {
return false, err return false, err
} }
} }

View File

@ -33,18 +33,16 @@ func CreateSSHKeys(writeLocation string) (string, error) {
return strings.TrimSuffix(string(b), "\n"), nil return strings.TrimSuffix(string(b), "\n"), nil
} }
func CreateSSHKeysPrefix(dir string, file string, passThru bool, skipExisting bool, prefix ...string) (string, error) { func CreateSSHKeysPrefix(identityPath string, passThru bool, skipExisting bool, prefix ...string) (string, error) {
location := filepath.Join(dir, file) _, e := os.Stat(identityPath)
_, e := os.Stat(location)
if !skipExisting || errors.Is(e, os.ErrNotExist) { if !skipExisting || errors.Is(e, os.ErrNotExist) {
if err := generatekeysPrefix(dir, file, passThru, prefix...); err != nil { if err := generatekeysPrefix(identityPath, passThru, prefix...); err != nil {
return "", err return "", err
} }
} else { } else {
fmt.Println("Keys already exist, reusing") fmt.Println("Keys already exist, reusing")
} }
b, err := os.ReadFile(filepath.Join(dir, file) + ".pub") b, err := os.ReadFile(identityPath + ".pub")
if err != nil { if err != nil {
return "", err return "", err
} }
@ -74,7 +72,14 @@ func generatekeys(writeLocation string) error {
} }
// generatekeys creates an ed25519 set of keys // generatekeys creates an ed25519 set of keys
func generatekeysPrefix(dir string, file string, passThru bool, prefix ...string) error { func generatekeysPrefix(identityPath string, passThru bool, prefix ...string) error {
dir := filepath.Dir(identityPath)
file := filepath.Base(identityPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("could not create ssh directory: %w", err)
}
args := append([]string{}, prefix[1:]...) args := append([]string{}, prefix[1:]...)
args = append(args, sshCommand...) args = append(args, sshCommand...)
args = append(args, file) args = append(args, file)

View File

@ -27,8 +27,8 @@ import (
"github.com/containers/common/pkg/config" "github.com/containers/common/pkg/config"
"github.com/containers/podman/v4/pkg/machine" "github.com/containers/podman/v4/pkg/machine"
"github.com/containers/podman/v4/pkg/rootless" "github.com/containers/podman/v4/pkg/rootless"
"github.com/containers/podman/v4/pkg/util"
"github.com/containers/podman/v4/utils" "github.com/containers/podman/v4/utils"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/ioutils" "github.com/containers/storage/pkg/ioutils"
"github.com/digitalocean/go-qemu/qmp" "github.com/digitalocean/go-qemu/qmp"
"github.com/docker/go-units" "github.com/docker/go-units"
@ -242,8 +242,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
var ( var (
key string key string
) )
sshDir := filepath.Join(homedir.Get(), ".ssh") v.IdentityPath = util.GetIdentityPath(v.Name)
v.IdentityPath = filepath.Join(sshDir, v.Name)
v.Rootful = opts.Rootful v.Rootful = opts.Rootful
switch opts.ImagePath { switch opts.ImagePath {
@ -320,7 +319,6 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
if len(opts.IgnitionPath) < 1 { if len(opts.IgnitionPath) < 1 {
uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", v.UID), strconv.Itoa(v.Port), v.RemoteUsername) uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", v.UID), strconv.Itoa(v.Port), v.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(v.Port), "root") uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(v.Port), "root")
identity := filepath.Join(sshDir, v.Name)
uris := []url.URL{uri, uriRoot} uris := []url.URL{uri, uriRoot}
names := []string{v.Name, v.Name + "-root"} names := []string{v.Name, v.Name + "-root"}
@ -332,7 +330,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
} }
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil { if err := machine.AddConnection(&uris[i], names[i], v.IdentityPath, opts.IsDefault && i == 0); err != nil {
return false, err return false, err
} }
} }

View File

@ -21,6 +21,7 @@ import (
"github.com/containers/common/pkg/config" "github.com/containers/common/pkg/config"
"github.com/containers/podman/v4/pkg/machine" "github.com/containers/podman/v4/pkg/machine"
"github.com/containers/podman/v4/pkg/machine/wsl/wutil" "github.com/containers/podman/v4/pkg/machine/wsl/wutil"
"github.com/containers/podman/v4/pkg/util"
"github.com/containers/podman/v4/utils" "github.com/containers/podman/v4/utils"
"github.com/containers/storage/pkg/homedir" "github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/ioutils" "github.com/containers/storage/pkg/ioutils"
@ -406,9 +407,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
} }
_ = setupWslProxyEnv() _ = setupWslProxyEnv()
homeDir := homedir.Get() v.IdentityPath = util.GetIdentityPath(v.Name)
sshDir := filepath.Join(homeDir, ".ssh")
v.IdentityPath = filepath.Join(sshDir, v.Name)
v.Rootful = opts.Rootful v.Rootful = opts.Rootful
v.Version = currentMachineVersion v.Version = currentMachineVersion
@ -438,7 +437,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
return false, err return false, err
} }
if err = createKeys(v, dist, sshDir); err != nil { if err = createKeys(v, dist); err != nil {
return false, err return false, err
} }
@ -449,7 +448,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
return false, err return false, err
} }
if err := setupConnections(v, opts, sshDir); err != nil { if err := setupConnections(v, opts); err != nil {
return false, err return false, err
} }
@ -502,10 +501,9 @@ func (v *MachineVM) writeConfig() error {
return nil return nil
} }
func setupConnections(v *MachineVM, opts machine.InitOptions, sshDir string) error { func setupConnections(v *MachineVM, opts machine.InitOptions) error {
uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootlessSock, strconv.Itoa(v.Port), v.RemoteUsername) uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootlessSock, strconv.Itoa(v.Port), v.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootfulSock, strconv.Itoa(v.Port), "root") uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootfulSock, strconv.Itoa(v.Port), "root")
identity := filepath.Join(sshDir, v.Name)
uris := []url.URL{uri, uriRoot} uris := []url.URL{uri, uriRoot}
names := []string{v.Name, v.Name + "-root"} names := []string{v.Name, v.Name + "-root"}
@ -517,7 +515,7 @@ func setupConnections(v *MachineVM, opts machine.InitOptions, sshDir string) err
} }
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil { if err := machine.AddConnection(&uris[i], names[i], v.IdentityPath, opts.IsDefault && i == 0); err != nil {
return err return err
} }
} }
@ -551,18 +549,14 @@ func provisionWSLDist(name string, imagePath string, prompt string) (string, err
return dist, nil return dist, nil
} }
func createKeys(v *MachineVM, dist string, sshDir string) error { func createKeys(v *MachineVM, dist string) error {
user := v.RemoteUsername user := v.RemoteUsername
if err := os.MkdirAll(sshDir, 0700); err != nil {
return fmt.Errorf("could not create ssh directory: %w", err)
}
if err := terminateDist(dist); err != nil { if err := terminateDist(dist); err != nil {
return fmt.Errorf("could not cycle WSL dist: %w", err) return fmt.Errorf("could not cycle WSL dist: %w", err)
} }
key, err := wslCreateKeys(sshDir, v.Name, dist) key, err := wslCreateKeys(v.IdentityPath, dist)
if err != nil { if err != nil {
return fmt.Errorf("could not create ssh keys: %w", err) return fmt.Errorf("could not create ssh keys: %w", err)
} }
@ -972,8 +966,8 @@ func wslPipe(input string, dist string, arg ...string) error {
return pipeCmdPassThrough("wsl", input, newArgs...) return pipeCmdPassThrough("wsl", input, newArgs...)
} }
func wslCreateKeys(sshDir string, name string, dist string) (string, error) { func wslCreateKeys(identityPath string, dist string) (string, error) {
return machine.CreateSSHKeysPrefix(sshDir, name, true, true, "wsl", "-u", "root", "-d", dist) return machine.CreateSSHKeysPrefix(identityPath, true, true, "wsl", "-u", "root", "-d", dist)
} }
func runCmdPassThrough(name string, arg ...string) error { func runCmdPassThrough(name string, arg ...string) error {

View File

@ -26,6 +26,7 @@ import (
"github.com/containers/podman/v4/pkg/rootless" "github.com/containers/podman/v4/pkg/rootless"
"github.com/containers/podman/v4/pkg/signal" "github.com/containers/podman/v4/pkg/signal"
"github.com/containers/storage/pkg/directory" "github.com/containers/storage/pkg/directory"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/idtools" "github.com/containers/storage/pkg/idtools"
stypes "github.com/containers/storage/types" stypes "github.com/containers/storage/types"
securejoin "github.com/cyphar/filepath-securejoin" securejoin "github.com/cyphar/filepath-securejoin"
@ -473,17 +474,8 @@ func ExitCode(err error) int {
return 126 return 126
} }
// HomeDir returns the home directory for the current user. func GetIdentityPath(name string) string {
func HomeDir() (string, error) { return filepath.Join(homedir.Get(), ".ssh", name)
home := os.Getenv("HOME")
if home == "" {
usr, err := user.LookupId(fmt.Sprintf("%d", rootless.GetRootlessUID()))
if err != nil {
return "", fmt.Errorf("unable to resolve HOME directory: %w", err)
}
home = usr.HomeDir
}
return home, nil
} }
func Tmpdir() string { func Tmpdir() string {

View File

@ -2,9 +2,11 @@ package util
import ( import (
"fmt" "fmt"
"path/filepath"
"testing" "testing"
"time" "time"
"github.com/containers/storage/pkg/homedir"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -56,6 +58,12 @@ func TestValidateSysctlBadSysctlWithExtraSpaces(t *testing.T) {
assert.Equal(t, err.Error(), fmt.Sprintf(expectedError, strSlice2[1])) assert.Equal(t, err.Error(), fmt.Sprintf(expectedError, strSlice2[1]))
} }
func TestGetIdentityPath(t *testing.T) {
name := "p-test"
identityPath := GetIdentityPath(name)
assert.Equal(t, identityPath, filepath.Join(homedir.Get(), ".ssh", name))
}
func TestCoresToPeriodAndQuota(t *testing.T) { func TestCoresToPeriodAndQuota(t *testing.T) {
cores := 1.0 cores := 1.0
expectedPeriod := DefaultCPUPeriod expectedPeriod := DefaultCPUPeriod