diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index 96af153041..fbdac6aebf 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -1,16 +1,22 @@ package shell import ( + "errors" "fmt" "os" "strings" + "github.com/shirou/gopsutil/v4/process" + "github.com/spf13/cast" + crcos "github.com/crc-org/crc/v2/pkg/os" ) var ( CommandRunner = crcos.NewLocalCommandRunner() WindowsSubsystemLinuxKernelMetadataFile = "/proc/version" + ErrUnknownShell = errors.New("Error: Unknown shell") + currentProcessSupplier = createCurrentProcess ) type Config struct { @@ -20,6 +26,20 @@ type Config struct { PathSuffix string } +// AbstractProcess is an interface created to abstract operations of the gopsutil library +// It is created so that we can override the behavior while writing unit tests by providing +// a mock implementation. +type AbstractProcess interface { + Name() (string, error) + Parent() (AbstractProcess, error) +} + +// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's +// process.Process object. This implementation is used in production code. +type RealProcess struct { + *process.Process +} + func GetShell(userShell string) (string, error) { if userShell != "" { if !isSupportedShell(userShell) { @@ -151,3 +171,50 @@ func IsWindowsSubsystemLinux() bool { } return false } + +func (p *RealProcess) Parent() (AbstractProcess, error) { + parentProcess, err := p.Process.Parent() + if err != nil { + return nil, err + } + return &RealProcess{parentProcess}, nil +} + +func createCurrentProcess() AbstractProcess { + currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid())) + if err != nil { + return nil + } + return &RealProcess{currentProcess} +} + +// detectShellByCheckingProcessTree attempts to identify the shell being used by +// examining the process tree starting from the given process ID. This function +// traverses up to ProcessDepthLimit levels up the process hierarchy. +// Parameters: +// - pid (int): The process ID to start checking from. +// +// Returns: +// - string: The name of the shell if found (e.g., "zsh", "bash", "fish"); +// otherwise, an empty string is returned if no matching shell is detected +// or an error occurs during the process tree traversal. +// +// Examples: +// +// shellName := detectShellByCheckingProcessTree(1234) +func detectShellByCheckingProcessTree(p AbstractProcess) string { + for p != nil { + processName, err := p.Name() + if err != nil { + return "" + } + if processName == "zsh" || processName == "bash" || processName == "fish" { + return processName + } + p, err = p.Parent() + if err != nil { + return "" + } + } + return "" +} diff --git a/pkg/os/shell/shell_darwin.go b/pkg/os/shell/shell_darwin.go deleted file mode 100644 index a7038da061..0000000000 --- a/pkg/os/shell/shell_darwin.go +++ /dev/null @@ -1,5 +0,0 @@ -package shell - -var ( - supportedShell = []string{"bash", "zsh", "fish"} -) diff --git a/pkg/os/shell/shell_linux.go b/pkg/os/shell/shell_linux.go deleted file mode 100644 index a7038da061..0000000000 --- a/pkg/os/shell/shell_linux.go +++ /dev/null @@ -1,5 +0,0 @@ -package shell - -var ( - supportedShell = []string{"bash", "zsh", "fish"} -) diff --git a/pkg/os/shell/shell_test.go b/pkg/os/shell/shell_test.go index a707555663..d0ba51133e 100644 --- a/pkg/os/shell/shell_test.go +++ b/pkg/os/shell/shell_test.go @@ -1,6 +1,7 @@ package shell import ( + "errors" "os" "path/filepath" "testing" @@ -47,6 +48,28 @@ func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (strin return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn } +// MockedProcess is a mock implementation of AbstractProcess for testing purposes. +type MockedProcess struct { + name string + parent *MockedProcess + nameGetFails bool + parentGetFails bool +} + +func (m MockedProcess) Parent() (AbstractProcess, error) { + if m.parentGetFails || m.parent == nil { + return nil, errors.New("failed to get the pid") + } + return m.parent, nil +} + +func (m MockedProcess) Name() (string, error) { + if m.nameGetFails { + return "", errors.New("failed to get the name") + } + return m.name, nil +} + func TestGetPathEnvString(t *testing.T) { tests := []struct { name string @@ -179,3 +202,16 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) { assert.Equal(t, "wsl", mockCommandExecutor.commandName) assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs) } + +func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess { + if len(processes) == 0 { + return nil + } + head := &processes[0] + current := head + for i := 1; i < len(processes); i++ { + current.parent = &processes[i] + current = current.parent + } + return head +} diff --git a/pkg/os/shell/shell_unix.go b/pkg/os/shell/shell_unix.go index 649f3882f5..614060a76c 100644 --- a/pkg/os/shell/shell_unix.go +++ b/pkg/os/shell/shell_unix.go @@ -4,50 +4,10 @@ package shell import ( - "errors" "fmt" - "os" "path/filepath" - - "github.com/shirou/gopsutil/v4/process" - "github.com/spf13/cast" -) - -var ( - ErrUnknownShell = errors.New("Error: Unknown shell") - currentProcessSupplier = createCurrentProcess ) -// AbstractProcess is an interface created to abstract operations of the gopsutil library -// It is created so that we can override the behavior while writing unit tests by providing -// a mock implementation. -type AbstractProcess interface { - Name() (string, error) - Parent() (AbstractProcess, error) -} - -// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's -// process.Process object. This implementation is used in production code. -type RealProcess struct { - *process.Process -} - -func (p *RealProcess) Parent() (AbstractProcess, error) { - parentProcess, err := p.Process.Parent() - if err != nil { - return nil, err - } - return &RealProcess{parentProcess}, nil -} - -func createCurrentProcess() AbstractProcess { - currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid())) - if err != nil { - return nil - } - return &RealProcess{currentProcess} -} - // detect detects user's current shell. func detect() (string, error) { detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier()) @@ -58,34 +18,3 @@ func detect() (string, error) { return filepath.Base(detectedShell), nil } - -// detectShellByCheckingProcessTree attempts to identify the shell being used by -// examining the process tree starting from the given process ID. This function -// traverses up to ProcessDepthLimit levels up the process hierarchy. -// Parameters: -// - pid (int): The process ID to start checking from. -// -// Returns: -// - string: The name of the shell if found (e.g., "zsh", "bash", "fish"); -// otherwise, an empty string is returned if no matching shell is detected -// or an error occurs during the process tree traversal. -// -// Examples: -// -// shellName := detectShellByCheckingProcessTree(1234) -func detectShellByCheckingProcessTree(p AbstractProcess) string { - for p != nil { - processName, err := p.Name() - if err != nil { - return "" - } - if processName == "zsh" || processName == "bash" || processName == "fish" { - return processName - } - p, err = p.Parent() - if err != nil { - return "" - } - } - return "" -} diff --git a/pkg/os/shell/shell_unix_test.go b/pkg/os/shell/shell_unix_test.go index 524825d134..a9c24cc086 100644 --- a/pkg/os/shell/shell_unix_test.go +++ b/pkg/os/shell/shell_unix_test.go @@ -5,35 +5,12 @@ package shell import ( "bytes" - "errors" "os" "testing" "github.com/stretchr/testify/assert" ) -// MockedProcess is a mock implementation of AbstractProcess for testing purposes. -type MockedProcess struct { - name string - parent *MockedProcess - nameGetFails bool - parentGetFails bool -} - -func (m MockedProcess) Parent() (AbstractProcess, error) { - if m.parentGetFails || m.parent == nil { - return nil, errors.New("failed to get the pid") - } - return m.parent, nil -} - -func (m MockedProcess) Name() (string, error) { - if m.nameGetFails { - return "", errors.New("failed to get the name") - } - return m.name, nil -} - func TestUnknownShell(t *testing.T) { tests := []struct { name string @@ -183,16 +160,3 @@ func TestGetCurrentProcess(t *testing.T) { assert.NoError(t, err) assert.Greater(t, len(currentProcessName), 0) } - -func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess { - if len(processes) == 0 { - return nil - } - head := &processes[0] - current := head - for i := 1; i < len(processes); i++ { - current.parent = &processes[i] - current = current.parent - } - return head -}