refactor: improve get ssh path duplicate code

Signed-off-by: Black-Hole1 <bh@bugs.cc>
This commit is contained in:
Black-Hole1
2023-05-31 16:27:19 +08:00
parent e91f6f16bf
commit c7a8d29f12
7 changed files with 42 additions and 49 deletions

View File

@ -17,7 +17,7 @@ import (
"github.com/containers/common/pkg/config"
"github.com/containers/podman/v4/pkg/machine"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/podman/v4/pkg/util"
"github.com/docker/go-units"
"github.com/sirupsen/logrus"
)
@ -123,8 +123,7 @@ func (m *MacMachine) Init(opts machine.InitOptions) (bool, error) {
}
m.VfkitHelper = *vfhelper
sshDir := filepath.Join(homedir.Get(), ".ssh")
m.IdentityPath = filepath.Join(sshDir, m.Name)
m.IdentityPath = util.GetIdentityPath(m.Name)
m.Rootful = opts.Rootful
m.RemoteUsername = opts.Username
@ -142,7 +141,7 @@ func (m *MacMachine) Init(opts machine.InitOptions) (bool, error) {
// TODO localhost needs to be restored here
uri := machine.SSHRemoteConnection.MakeSSHURL("192.168.64.2", fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL("localhost", "/run/podman/podman.sock", strconv.Itoa(m.Port), "root")
identity := filepath.Join(sshDir, m.Name)
identity := m.IdentityPath
uris := []url.URL{uri, uriRoot}
names := []string{m.Name, m.Name + "-root"}

View File

@ -18,8 +18,8 @@ import (
"github.com/containers/common/pkg/config"
"github.com/containers/libhvee/pkg/hypervctl"
"github.com/containers/podman/v4/pkg/machine"
"github.com/containers/podman/v4/pkg/util"
"github.com/containers/podman/v4/utils"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/ioutils"
"github.com/docker/go-units"
"github.com/sirupsen/logrus"
@ -91,9 +91,7 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) {
}
m.NetworkHVSock = *networkHVSock
m.ReadyHVSock = *eventHVSocket
sshDir := filepath.Join(homedir.Get(), ".ssh")
m.IdentityPath = filepath.Join(sshDir, m.Name)
m.IdentityPath = util.GetIdentityPath(m.Name)
// TODO This needs to be fixed in c-common
m.RemoteUsername = "core"
@ -111,7 +109,6 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) {
if len(opts.IgnitionPath) < 1 {
uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", m.UID), strconv.Itoa(m.Port), m.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(m.Port), "root")
identity := filepath.Join(sshDir, m.Name)
uris := []url.URL{uri, uriRoot}
names := []string{m.Name, m.Name + "-root"}
@ -123,7 +120,7 @@ func (m *HyperVMachine) Init(opts machine.InitOptions) (bool, error) {
}
for i := 0; i < 2; i++ {
if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil {
if err := machine.AddConnection(&uris[i], names[i], m.IdentityPath, opts.IsDefault && i == 0); err != nil {
return false, err
}
}

View File

@ -33,18 +33,16 @@ func CreateSSHKeys(writeLocation string) (string, error) {
return strings.TrimSuffix(string(b), "\n"), nil
}
func CreateSSHKeysPrefix(dir string, file string, passThru bool, skipExisting bool, prefix ...string) (string, error) {
location := filepath.Join(dir, file)
_, e := os.Stat(location)
func CreateSSHKeysPrefix(identityPath string, passThru bool, skipExisting bool, prefix ...string) (string, error) {
_, e := os.Stat(identityPath)
if !skipExisting || errors.Is(e, os.ErrNotExist) {
if err := generatekeysPrefix(dir, file, passThru, prefix...); err != nil {
if err := generatekeysPrefix(identityPath, passThru, prefix...); err != nil {
return "", err
}
} else {
fmt.Println("Keys already exist, reusing")
}
b, err := os.ReadFile(filepath.Join(dir, file) + ".pub")
b, err := os.ReadFile(identityPath + ".pub")
if err != nil {
return "", err
}
@ -74,7 +72,14 @@ func generatekeys(writeLocation string) error {
}
// generatekeys creates an ed25519 set of keys
func generatekeysPrefix(dir string, file string, passThru bool, prefix ...string) error {
func generatekeysPrefix(identityPath string, passThru bool, prefix ...string) error {
dir := filepath.Dir(identityPath)
file := filepath.Base(identityPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("could not create ssh directory: %w", err)
}
args := append([]string{}, prefix[1:]...)
args = append(args, sshCommand...)
args = append(args, file)

View File

@ -27,8 +27,8 @@ import (
"github.com/containers/common/pkg/config"
"github.com/containers/podman/v4/pkg/machine"
"github.com/containers/podman/v4/pkg/rootless"
"github.com/containers/podman/v4/pkg/util"
"github.com/containers/podman/v4/utils"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/ioutils"
"github.com/digitalocean/go-qemu/qmp"
"github.com/docker/go-units"
@ -242,8 +242,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
var (
key string
)
sshDir := filepath.Join(homedir.Get(), ".ssh")
v.IdentityPath = filepath.Join(sshDir, v.Name)
v.IdentityPath = util.GetIdentityPath(v.Name)
v.Rootful = opts.Rootful
switch opts.ImagePath {
@ -320,7 +319,6 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
if len(opts.IgnitionPath) < 1 {
uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, fmt.Sprintf("/run/user/%d/podman/podman.sock", v.UID), strconv.Itoa(v.Port), v.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, "/run/podman/podman.sock", strconv.Itoa(v.Port), "root")
identity := filepath.Join(sshDir, v.Name)
uris := []url.URL{uri, uriRoot}
names := []string{v.Name, v.Name + "-root"}
@ -332,7 +330,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
}
for i := 0; i < 2; i++ {
if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil {
if err := machine.AddConnection(&uris[i], names[i], v.IdentityPath, opts.IsDefault && i == 0); err != nil {
return false, err
}
}

View File

@ -21,6 +21,7 @@ import (
"github.com/containers/common/pkg/config"
"github.com/containers/podman/v4/pkg/machine"
"github.com/containers/podman/v4/pkg/machine/wsl/wutil"
"github.com/containers/podman/v4/pkg/util"
"github.com/containers/podman/v4/utils"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/ioutils"
@ -406,9 +407,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
}
_ = setupWslProxyEnv()
homeDir := homedir.Get()
sshDir := filepath.Join(homeDir, ".ssh")
v.IdentityPath = filepath.Join(sshDir, v.Name)
v.IdentityPath = util.GetIdentityPath(v.Name)
v.Rootful = opts.Rootful
v.Version = currentMachineVersion
@ -438,7 +437,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
return false, err
}
if err = createKeys(v, dist, sshDir); err != nil {
if err = createKeys(v, dist); err != nil {
return false, err
}
@ -449,7 +448,7 @@ func (v *MachineVM) Init(opts machine.InitOptions) (bool, error) {
return false, err
}
if err := setupConnections(v, opts, sshDir); err != nil {
if err := setupConnections(v, opts); err != nil {
return false, err
}
@ -502,10 +501,9 @@ func (v *MachineVM) writeConfig() error {
return nil
}
func setupConnections(v *MachineVM, opts machine.InitOptions, sshDir string) error {
func setupConnections(v *MachineVM, opts machine.InitOptions) error {
uri := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootlessSock, strconv.Itoa(v.Port), v.RemoteUsername)
uriRoot := machine.SSHRemoteConnection.MakeSSHURL(machine.LocalhostIP, rootfulSock, strconv.Itoa(v.Port), "root")
identity := filepath.Join(sshDir, v.Name)
uris := []url.URL{uri, uriRoot}
names := []string{v.Name, v.Name + "-root"}
@ -517,7 +515,7 @@ func setupConnections(v *MachineVM, opts machine.InitOptions, sshDir string) err
}
for i := 0; i < 2; i++ {
if err := machine.AddConnection(&uris[i], names[i], identity, opts.IsDefault && i == 0); err != nil {
if err := machine.AddConnection(&uris[i], names[i], v.IdentityPath, opts.IsDefault && i == 0); err != nil {
return err
}
}
@ -551,18 +549,14 @@ func provisionWSLDist(name string, imagePath string, prompt string) (string, err
return dist, nil
}
func createKeys(v *MachineVM, dist string, sshDir string) error {
func createKeys(v *MachineVM, dist string) error {
user := v.RemoteUsername
if err := os.MkdirAll(sshDir, 0700); err != nil {
return fmt.Errorf("could not create ssh directory: %w", err)
}
if err := terminateDist(dist); err != nil {
return fmt.Errorf("could not cycle WSL dist: %w", err)
}
key, err := wslCreateKeys(sshDir, v.Name, dist)
key, err := wslCreateKeys(v.IdentityPath, dist)
if err != nil {
return fmt.Errorf("could not create ssh keys: %w", err)
}
@ -972,8 +966,8 @@ func wslPipe(input string, dist string, arg ...string) error {
return pipeCmdPassThrough("wsl", input, newArgs...)
}
func wslCreateKeys(sshDir string, name string, dist string) (string, error) {
return machine.CreateSSHKeysPrefix(sshDir, name, true, true, "wsl", "-u", "root", "-d", dist)
func wslCreateKeys(identityPath string, dist string) (string, error) {
return machine.CreateSSHKeysPrefix(identityPath, true, true, "wsl", "-u", "root", "-d", dist)
}
func runCmdPassThrough(name string, arg ...string) error {

View File

@ -26,6 +26,7 @@ import (
"github.com/containers/podman/v4/pkg/rootless"
"github.com/containers/podman/v4/pkg/signal"
"github.com/containers/storage/pkg/directory"
"github.com/containers/storage/pkg/homedir"
"github.com/containers/storage/pkg/idtools"
stypes "github.com/containers/storage/types"
securejoin "github.com/cyphar/filepath-securejoin"
@ -473,17 +474,8 @@ func ExitCode(err error) int {
return 126
}
// HomeDir returns the home directory for the current user.
func HomeDir() (string, error) {
home := os.Getenv("HOME")
if home == "" {
usr, err := user.LookupId(fmt.Sprintf("%d", rootless.GetRootlessUID()))
if err != nil {
return "", fmt.Errorf("unable to resolve HOME directory: %w", err)
}
home = usr.HomeDir
}
return home, nil
func GetIdentityPath(name string) string {
return filepath.Join(homedir.Get(), ".ssh", name)
}
func Tmpdir() string {

View File

@ -2,9 +2,11 @@ package util
import (
"fmt"
"path/filepath"
"testing"
"time"
"github.com/containers/storage/pkg/homedir"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/assert"
)
@ -56,6 +58,12 @@ func TestValidateSysctlBadSysctlWithExtraSpaces(t *testing.T) {
assert.Equal(t, err.Error(), fmt.Sprintf(expectedError, strSlice2[1]))
}
func TestGetIdentityPath(t *testing.T) {
name := "p-test"
identityPath := GetIdentityPath(name)
assert.Equal(t, identityPath, filepath.Join(homedir.Get(), ".ssh", name))
}
func TestCoresToPeriodAndQuota(t *testing.T) {
cores := 1.0
expectedPeriod := DefaultCPUPeriod