From 9f7677bb2cd3274a29820dddd42f4e9e344fb7e5 Mon Sep 17 00:00:00 2001 From: Rohan Kumar Date: Sat, 1 Feb 2025 01:00:58 +0530 Subject: [PATCH] refactor (shell) : detect shell on windows using gopsutil too (#4588) As a follow up to #4572, updating shell detection logic on windows to also rely on gopsutil library to detect the shell name by inspecting parent processes of `crc.exe` process. Signed-off-by: Rohan Kumar --- pkg/os/shell/shell.go | 67 +++++++++++++ pkg/os/shell/shell_darwin.go | 5 - pkg/os/shell/shell_linux.go | 5 - pkg/os/shell/shell_test.go | 36 +++++++ pkg/os/shell/shell_unix.go | 69 +------------- pkg/os/shell/shell_unix_test.go | 36 ------- pkg/os/shell/shell_windows.go | 83 ++-------------- pkg/os/shell/shell_windows_test.go | 146 +++++++++++++++++++++++------ pkg/os/util.go | 9 ++ pkg/os/util_test.go | 22 +++++ 10 files changed, 257 insertions(+), 221 deletions(-) delete mode 100644 pkg/os/shell/shell_darwin.go delete mode 100644 pkg/os/shell/shell_linux.go diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index 96af153041..6d36dfab98 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -1,16 +1,22 @@ package shell import ( + "errors" "fmt" "os" "strings" + "github.com/shirou/gopsutil/v4/process" + "github.com/spf13/cast" + crcos "github.com/crc-org/crc/v2/pkg/os" ) var ( CommandRunner = crcos.NewLocalCommandRunner() WindowsSubsystemLinuxKernelMetadataFile = "/proc/version" + ErrUnknownShell = errors.New("Error: Unknown shell") + currentProcessSupplier = createCurrentProcess ) type Config struct { @@ -20,6 +26,20 @@ type Config struct { PathSuffix string } +// AbstractProcess is an interface created to abstract operations of the gopsutil library +// It is created so that we can override the behavior while writing unit tests by providing +// a mock implementation. +type AbstractProcess interface { + Name() (string, error) + Parent() (AbstractProcess, error) +} + +// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's +// process.Process object. This implementation is used in production code. +type RealProcess struct { + *process.Process +} + func GetShell(userShell string) (string, error) { if userShell != "" { if !isSupportedShell(userShell) { @@ -151,3 +171,50 @@ func IsWindowsSubsystemLinux() bool { } return false } + +func (p *RealProcess) Parent() (AbstractProcess, error) { + parentProcess, err := p.Process.Parent() + if err != nil { + return nil, err + } + return &RealProcess{parentProcess}, nil +} + +func createCurrentProcess() AbstractProcess { + currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid())) + if err != nil { + return nil + } + return &RealProcess{currentProcess} +} + +// detectShellByCheckingProcessTree attempts to identify the shell being used by +// examining the process tree starting from the given process ID. This function +// traverses up to ProcessDepthLimit levels up the process hierarchy. +// Parameters: +// - pid (int): The process ID to start checking from. +// +// Returns: +// - string: The name of the shell if found (e.g., "zsh", "bash", "fish"); +// otherwise, an empty string is returned if no matching shell is detected +// or an error occurs during the process tree traversal. +// +// Examples: +// +// shellName := detectShellByCheckingProcessTree(1234) +func detectShellByCheckingProcessTree(p AbstractProcess) string { + for p != nil { + processName, err := p.Name() + if err != nil { + return "" + } + if crcos.IsPresentInList(supportedShell, processName) { + return processName + } + p, err = p.Parent() + if err != nil { + return "" + } + } + return "" +} diff --git a/pkg/os/shell/shell_darwin.go b/pkg/os/shell/shell_darwin.go deleted file mode 100644 index a7038da061..0000000000 --- a/pkg/os/shell/shell_darwin.go +++ /dev/null @@ -1,5 +0,0 @@ -package shell - -var ( - supportedShell = []string{"bash", "zsh", "fish"} -) diff --git a/pkg/os/shell/shell_linux.go b/pkg/os/shell/shell_linux.go deleted file mode 100644 index a7038da061..0000000000 --- a/pkg/os/shell/shell_linux.go +++ /dev/null @@ -1,5 +0,0 @@ -package shell - -var ( - supportedShell = []string{"bash", "zsh", "fish"} -) diff --git a/pkg/os/shell/shell_test.go b/pkg/os/shell/shell_test.go index a707555663..d0ba51133e 100644 --- a/pkg/os/shell/shell_test.go +++ b/pkg/os/shell/shell_test.go @@ -1,6 +1,7 @@ package shell import ( + "errors" "os" "path/filepath" "testing" @@ -47,6 +48,28 @@ func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (strin return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn } +// MockedProcess is a mock implementation of AbstractProcess for testing purposes. +type MockedProcess struct { + name string + parent *MockedProcess + nameGetFails bool + parentGetFails bool +} + +func (m MockedProcess) Parent() (AbstractProcess, error) { + if m.parentGetFails || m.parent == nil { + return nil, errors.New("failed to get the pid") + } + return m.parent, nil +} + +func (m MockedProcess) Name() (string, error) { + if m.nameGetFails { + return "", errors.New("failed to get the name") + } + return m.name, nil +} + func TestGetPathEnvString(t *testing.T) { tests := []struct { name string @@ -179,3 +202,16 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) { assert.Equal(t, "wsl", mockCommandExecutor.commandName) assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs) } + +func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess { + if len(processes) == 0 { + return nil + } + head := &processes[0] + current := head + for i := 1; i < len(processes); i++ { + current.parent = &processes[i] + current = current.parent + } + return head +} diff --git a/pkg/os/shell/shell_unix.go b/pkg/os/shell/shell_unix.go index 649f3882f5..43df1f7b80 100644 --- a/pkg/os/shell/shell_unix.go +++ b/pkg/os/shell/shell_unix.go @@ -4,50 +4,14 @@ package shell import ( - "errors" "fmt" - "os" "path/filepath" - - "github.com/shirou/gopsutil/v4/process" - "github.com/spf13/cast" ) var ( - ErrUnknownShell = errors.New("Error: Unknown shell") - currentProcessSupplier = createCurrentProcess + supportedShell = []string{"bash", "zsh", "fish"} ) -// AbstractProcess is an interface created to abstract operations of the gopsutil library -// It is created so that we can override the behavior while writing unit tests by providing -// a mock implementation. -type AbstractProcess interface { - Name() (string, error) - Parent() (AbstractProcess, error) -} - -// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's -// process.Process object. This implementation is used in production code. -type RealProcess struct { - *process.Process -} - -func (p *RealProcess) Parent() (AbstractProcess, error) { - parentProcess, err := p.Process.Parent() - if err != nil { - return nil, err - } - return &RealProcess{parentProcess}, nil -} - -func createCurrentProcess() AbstractProcess { - currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid())) - if err != nil { - return nil - } - return &RealProcess{currentProcess} -} - // detect detects user's current shell. func detect() (string, error) { detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier()) @@ -58,34 +22,3 @@ func detect() (string, error) { return filepath.Base(detectedShell), nil } - -// detectShellByCheckingProcessTree attempts to identify the shell being used by -// examining the process tree starting from the given process ID. This function -// traverses up to ProcessDepthLimit levels up the process hierarchy. -// Parameters: -// - pid (int): The process ID to start checking from. -// -// Returns: -// - string: The name of the shell if found (e.g., "zsh", "bash", "fish"); -// otherwise, an empty string is returned if no matching shell is detected -// or an error occurs during the process tree traversal. -// -// Examples: -// -// shellName := detectShellByCheckingProcessTree(1234) -func detectShellByCheckingProcessTree(p AbstractProcess) string { - for p != nil { - processName, err := p.Name() - if err != nil { - return "" - } - if processName == "zsh" || processName == "bash" || processName == "fish" { - return processName - } - p, err = p.Parent() - if err != nil { - return "" - } - } - return "" -} diff --git a/pkg/os/shell/shell_unix_test.go b/pkg/os/shell/shell_unix_test.go index 524825d134..a9c24cc086 100644 --- a/pkg/os/shell/shell_unix_test.go +++ b/pkg/os/shell/shell_unix_test.go @@ -5,35 +5,12 @@ package shell import ( "bytes" - "errors" "os" "testing" "github.com/stretchr/testify/assert" ) -// MockedProcess is a mock implementation of AbstractProcess for testing purposes. -type MockedProcess struct { - name string - parent *MockedProcess - nameGetFails bool - parentGetFails bool -} - -func (m MockedProcess) Parent() (AbstractProcess, error) { - if m.parentGetFails || m.parent == nil { - return nil, errors.New("failed to get the pid") - } - return m.parent, nil -} - -func (m MockedProcess) Name() (string, error) { - if m.nameGetFails { - return "", errors.New("failed to get the name") - } - return m.name, nil -} - func TestUnknownShell(t *testing.T) { tests := []struct { name string @@ -183,16 +160,3 @@ func TestGetCurrentProcess(t *testing.T) { assert.NoError(t, err) assert.Greater(t, len(currentProcessName), 0) } - -func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess { - if len(processes) == 0 { - return nil - } - head := &processes[0] - current := head - for i := 1; i < len(processes); i++ { - current.parent = &processes[i] - current = current.parent - } - return head -} diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 52f26654b3..fd15000ef4 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -1,64 +1,19 @@ package shell import ( - "fmt" - "math" - "os" - "path/filepath" "sort" "strconv" "strings" - "syscall" - "unsafe" + + crcos "github.com/crc-org/crc/v2/pkg/os" "github.com/crc-org/crc/v2/pkg/crc/logging" ) var ( - supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"} + supportedShell = []string{"cmd", "cmd.exe", "powershell", "powershell.exe", "wsl.exe", "bash.exe", "bash", "zsh", "fish"} ) -// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go -func getProcessEntry(pid uint32) (pe *syscall.ProcessEntry32, err error) { - snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0) - if err != nil { - return nil, err - } - defer func() { - _ = syscall.CloseHandle(syscall.Handle(snapshot)) - }() - - var processEntry syscall.ProcessEntry32 - processEntry.Size = uint32(unsafe.Sizeof(processEntry)) - err = syscall.Process32First(snapshot, &processEntry) - if err != nil { - return nil, err - } - - for { - if processEntry.ProcessID == pid { - pe = &processEntry - return - } - - err = syscall.Process32Next(snapshot, &processEntry) - if err != nil { - return nil, err - } - } -} - -// getNameAndItsPpid returns the exe file name its parent process id. -func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) { - pe, err := getProcessEntry(pid) - if err != nil { - return "", 0, err - } - - name := syscall.UTF16ToString(pe.ExeFile[:]) - return name, pe.ParentProcessID, nil -} - func shellType(shell string, defaultShell string) string { switch { case strings.Contains(strings.ToLower(shell), "powershell"): @@ -69,7 +24,7 @@ func shellType(shell string, defaultShell string) string { return "cmd" case strings.Contains(strings.ToLower(shell), "wsl"): return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid=,comm="}) - case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"): + case strings.Contains(strings.ToLower(shell), "bash"): return "bash" default: return defaultShell @@ -77,31 +32,7 @@ func shellType(shell string, defaultShell string) string { } func detect() (string, error) { - shell := os.Getenv("SHELL") - - if shell == "" { - pid := os.Getppid() - if pid < 0 || pid > math.MaxUint32 { - return "", fmt.Errorf("integer overflow for pid: %v", pid) - } - shell, shellppid, err := getNameAndItsPpid(uint32(pid)) - if err != nil { - return "cmd", err // defaulting to cmd - } - shell = shellType(shell, "") - if shell == "" { - shell, _, err := getNameAndItsPpid(shellppid) - if err != nil { - return "cmd", err // defaulting to cmd - } - return shellType(shell, "cmd"), nil - } - return shell, nil - } - - if os.Getenv("__fish_bin_dir") != "" { - return "fish", nil - } + shell := detectShellByCheckingProcessTree(currentProcessSupplier()) return shellType(shell, "cmd"), nil } @@ -163,9 +94,7 @@ func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string { lines := strings.Split(psCommandOutput, "\n") for _, line := range lines { lineParts := strings.Split(strings.TrimSpace(line), " ") - if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") || - strings.Contains(lineParts[1], "bash") || - strings.Contains(lineParts[1], "fish")) { + if len(lineParts) == 2 && crcos.IsPresentInList(supportedShell, lineParts[1]) { parsedProcessID, err := strconv.Atoi(lineParts[0]) if err == nil { processOutputs = append(processOutputs, ProcessOutput{ diff --git a/pkg/os/shell/shell_windows_test.go b/pkg/os/shell/shell_windows_test.go index 06fecdad79..dc7f5f9baf 100644 --- a/pkg/os/shell/shell_windows_test.go +++ b/pkg/os/shell/shell_windows_test.go @@ -1,51 +1,137 @@ package shell import ( - "math" - "os" "testing" "github.com/stretchr/testify/assert" ) -func TestDetect(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "") - - shell, err := detect() +func TestDetect_WhenUnknownShell_ThenDefaultToCmdShell(t *testing.T) { + tests := []struct { + name string + processTree []MockedProcess + expectedShellType string + }{ + { + "failure to get process details for given pid", + []MockedProcess{}, + "", + }, + { + "failure while getting name of process", + []MockedProcess{ + { + name: "crc.exe", + }, + { + nameGetFails: true, + }, + }, + "", + }, + { + "failure while getting ppid of process", + []MockedProcess{ + { + name: "crc.exe", + }, + { + parentGetFails: true, + }, + }, + "", + }, + { + "failure when no shell process in process tree", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "unknown.exe", + }, + }, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + currentProcessSupplier = func() AbstractProcess { + return createNewMockProcessTreeFrom(tt.processTree) + } - assert.Contains(t, supportedShell, shell) - assert.NoError(t, err) -} + // When + shell, err := detect() -func TestGetNameAndItsPpidOfCurrent(t *testing.T) { - pid := os.Getpid() - if pid < 0 || pid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") - } - shell, shellppid, err := getNameAndItsPpid(uint32(pid)) - assert.Equal(t, "shell.test.exe", shell) - ppid := os.Getppid() - if ppid < 0 || ppid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") + // Then + assert.NoError(t, err) + assert.Equal(t, "cmd", shell) + }) } - assert.Equal(t, uint32(ppid), shellppid) - assert.NoError(t, err) } -func TestGetNameAndItsPpidOfParent(t *testing.T) { - pid := os.Getppid() - if pid < 0 || pid > math.MaxUint32 { - assert.Fail(t, "integer overflow detected") +func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) { + tests := []struct { + name string + processTree []MockedProcess + expectedShellType string + }{ + { + "bash shell, then detect bash shell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "bash.exe", + }, + }, + "bash", + }, + { + "powershell, then detect powershell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "powershell.exe", + }, + }, + "powershell", + }, + { + "cmd shell, then detect fish shell", + []MockedProcess{ + { + name: "crc.exe", + }, + { + name: "cmd.exe", + }, + }, + "cmd", + }, } - shell, _, err := getNameAndItsPpid(uint32(pid)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + currentProcessSupplier = func() AbstractProcess { + return createNewMockProcessTreeFrom(tt.processTree) + } + // When + shell, err := detect() - assert.Equal(t, "go.exe", shell) - assert.NoError(t, err) + // Then + assert.Equal(t, tt.expectedShellType, shell) + assert.NoError(t, err) + }) + } } func TestSupportedShells(t *testing.T) { - assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell) + assert.Equal(t, []string{"cmd", "cmd.exe", "powershell", "powershell.exe", "wsl.exe", "bash.exe", "bash", "zsh", "fish"}, supportedShell) } func TestShellType(t *testing.T) { diff --git a/pkg/os/util.go b/pkg/os/util.go index df097312c6..7b78426509 100644 --- a/pkg/os/util.go +++ b/pkg/os/util.go @@ -127,3 +127,12 @@ func RemoveFileGlob(glob string) error { } return nil } + +func IsPresentInList(arr []string, str string) bool { + for _, a := range arr { + if a == str { + return true + } + } + return false +} diff --git a/pkg/os/util_test.go b/pkg/os/util_test.go index baa8ab8315..f576f51d86 100644 --- a/pkg/os/util_test.go +++ b/pkg/os/util_test.go @@ -94,3 +94,25 @@ func TestFileExists(t *testing.T) { filename = filepath.Join(dirname, "nonexistent") assert.False(t, FileExists(filename)) } + +func TestIsPresentInList(t *testing.T) { + tests := []struct { + arrayList []string + item string + expected bool + }{ + {arrayList: []string{"apple", "banana", "cherry"}, item: "banana", expected: true}, + {arrayList: []string{"apple", "banana", "cherry"}, item: "grape", expected: false}, + {arrayList: []string{}, item: "apple", expected: false}, + {arrayList: []string{"apple", "apple", "apple"}, item: "apple", expected: true}, + } + + for _, tt := range tests { + t.Run(tt.item, func(t *testing.T) { + got := IsPresentInList(tt.arrayList, tt.item) + if got != tt.expected { + t.Errorf("IsPresentInList(%v, %q) = %v; expected %v", tt.arrayList, tt.item, got, tt.expected) + } + }) + } +}