Skip to content

Commit

Permalink
refactor (shell) : Windows Shell detection uses gopsutil (#4588)
Browse files Browse the repository at this point in the history
+ Update `shell_windows.go` to use detectShellByCheckingProcessTree
instead of relying on SHELL environment variable.
+ Remove hardcoded check from detectShellByCheckingProcessTree for shell
  types, use already present supportedShell slice.

Signed-off-by: Rohan Kumar <[email protected]>
  • Loading branch information
rohanKanojia authored and praveenkumar committed Feb 11, 2025
1 parent 644f23f commit 2499dd2
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 108 deletions.
5 changes: 4 additions & 1 deletion pkg/os/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"os"
"slices"
"strings"

"github.com/shirou/gopsutil/v4/process"
Expand Down Expand Up @@ -208,7 +209,9 @@ func detectShellByCheckingProcessTree(p AbstractProcess) string {
if err != nil {
return ""
}
if processName == "zsh" || processName == "bash" || processName == "fish" {
if slices.ContainsFunc(supportedShell, func(listElem string) bool {
return strings.HasPrefix(processName, listElem)
}) {
return processName
}
p, err = p.Parent()
Expand Down
4 changes: 4 additions & 0 deletions pkg/os/shell/shell_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"path/filepath"
)

var (
supportedShell = []string{"bash", "zsh", "fish"}
)

// detect detects user's current shell.
func detect() (string, error) {
detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier())
Expand Down
84 changes: 7 additions & 77 deletions pkg/os/shell/shell_windows.go
Original file line number Diff line number Diff line change
@@ -1,64 +1,18 @@
package shell

import (
"fmt"
"math"
"os"
"path/filepath"
"slices"
"sort"
"strconv"
"strings"
"syscall"
"unsafe"

"github.com/crc-org/crc/v2/pkg/crc/logging"
)

var (
supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"}
supportedShell = []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}
)

// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go
func getProcessEntry(pid uint32) (pe *syscall.ProcessEntry32, err error) {
snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0)
if err != nil {
return nil, err
}
defer func() {
_ = syscall.CloseHandle(syscall.Handle(snapshot))
}()

var processEntry syscall.ProcessEntry32
processEntry.Size = uint32(unsafe.Sizeof(processEntry))
err = syscall.Process32First(snapshot, &processEntry)
if err != nil {
return nil, err
}

for {
if processEntry.ProcessID == pid {
pe = &processEntry
return
}

err = syscall.Process32Next(snapshot, &processEntry)
if err != nil {
return nil, err
}
}
}

// getNameAndItsPpid returns the exe file name its parent process id.
func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) {
pe, err := getProcessEntry(pid)
if err != nil {
return "", 0, err
}

name := syscall.UTF16ToString(pe.ExeFile[:])
return name, pe.ParentProcessID, nil
}

func shellType(shell string, defaultShell string) string {
switch {
case strings.Contains(strings.ToLower(shell), "powershell"):
Expand All @@ -69,39 +23,15 @@ func shellType(shell string, defaultShell string) string {
return "cmd"
case strings.Contains(strings.ToLower(shell), "wsl"):
return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid=,comm="})
case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"):
case strings.Contains(strings.ToLower(shell), "bash"):
return "bash"
default:
return defaultShell
}
}

func detect() (string, error) {
shell := os.Getenv("SHELL")

if shell == "" {
pid := os.Getppid()
if pid < 0 || pid > math.MaxUint32 {
return "", fmt.Errorf("integer overflow for pid: %v", pid)
}
shell, shellppid, err := getNameAndItsPpid(uint32(pid))
if err != nil {
return "cmd", err // defaulting to cmd
}
shell = shellType(shell, "")
if shell == "" {
shell, _, err := getNameAndItsPpid(shellppid)
if err != nil {
return "cmd", err // defaulting to cmd
}
return shellType(shell, "cmd"), nil
}
return shell, nil
}

if os.Getenv("__fish_bin_dir") != "" {
return "fish", nil
}
shell := detectShellByCheckingProcessTree(currentProcessSupplier())

return shellType(shell, "cmd"), nil
}
Expand Down Expand Up @@ -163,9 +93,9 @@ func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string {
lines := strings.Split(psCommandOutput, "\n")
for _, line := range lines {
lineParts := strings.Split(strings.TrimSpace(line), " ")
if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") ||
strings.Contains(lineParts[1], "bash") ||
strings.Contains(lineParts[1], "fish")) {
if len(lineParts) == 2 && slices.ContainsFunc(supportedShell, func(listElem string) bool {
return strings.HasPrefix(lineParts[1], listElem)
}) {
parsedProcessID, err := strconv.Atoi(lineParts[0])
if err == nil {
processOutputs = append(processOutputs, ProcessOutput{
Expand Down
146 changes: 116 additions & 30 deletions pkg/os/shell/shell_windows_test.go
Original file line number Diff line number Diff line change
@@ -1,51 +1,137 @@
package shell

import (
"math"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestDetect(t *testing.T) {
defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL"))
os.Setenv("SHELL", "")

shell, err := detect()
func TestDetect_WhenUnknownShell_ThenDefaultToCmdShell(t *testing.T) {
tests := []struct {
name string
processTree []MockedProcess
expectedShellType string
}{
{
"failure to get process details for given pid",
[]MockedProcess{},
"",
},
{
"failure while getting name of process",
[]MockedProcess{
{
name: "crc.exe",
},
{
nameGetFails: true,
},
},
"",
},
{
"failure while getting ppid of process",
[]MockedProcess{
{
name: "crc.exe",
},
{
parentGetFails: true,
},
},
"",
},
{
"failure when no shell process in process tree",
[]MockedProcess{
{
name: "crc.exe",
},
{
name: "unknown.exe",
},
},
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Given
currentProcessSupplier = func() AbstractProcess {
return createNewMockProcessTreeFrom(tt.processTree)
}

assert.Contains(t, supportedShell, shell)
assert.NoError(t, err)
}
// When
shell, err := detect()

func TestGetNameAndItsPpidOfCurrent(t *testing.T) {
pid := os.Getpid()
if pid < 0 || pid > math.MaxUint32 {
assert.Fail(t, "integer overflow detected")
}
shell, shellppid, err := getNameAndItsPpid(uint32(pid))
assert.Equal(t, "shell.test.exe", shell)
ppid := os.Getppid()
if ppid < 0 || ppid > math.MaxUint32 {
assert.Fail(t, "integer overflow detected")
// Then
assert.NoError(t, err)
assert.Equal(t, "cmd", shell)
})
}
assert.Equal(t, uint32(ppid), shellppid)
assert.NoError(t, err)
}

func TestGetNameAndItsPpidOfParent(t *testing.T) {
pid := os.Getppid()
if pid < 0 || pid > math.MaxUint32 {
assert.Fail(t, "integer overflow detected")
func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) {
tests := []struct {
name string
processTree []MockedProcess
expectedShellType string
}{
{
"bash shell, then detect bash shell",
[]MockedProcess{
{
name: "crc.exe",
},
{
name: "bash.exe",
},
},
"bash",
},
{
"powershell, then detect powershell",
[]MockedProcess{
{
name: "crc.exe",
},
{
name: "powershell.exe",
},
},
"powershell",
},
{
"cmd shell, then detect fish shell",
[]MockedProcess{
{
name: "crc.exe",
},
{
name: "cmd.exe",
},
},
"cmd",
},
}
shell, _, err := getNameAndItsPpid(uint32(pid))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Given
currentProcessSupplier = func() AbstractProcess {
return createNewMockProcessTreeFrom(tt.processTree)
}
// When
shell, err := detect()

assert.Equal(t, "go.exe", shell)
assert.NoError(t, err)
// Then
assert.Equal(t, tt.expectedShellType, shell)
assert.NoError(t, err)
})
}
}

func TestSupportedShells(t *testing.T) {
assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell)
assert.Equal(t, []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}, supportedShell)
}

func TestShellType(t *testing.T) {
Expand Down

0 comments on commit 2499dd2

Please sign in to comment.