Skip to content

Commit 9e42792

Browse files
committed
refactor (shell) : Windows Shell detection uses gopsutil (#4588)
+ Update `shell_windows.go` to use detectShellByCheckingProcessTree instead of relying on SHELL environment variable. + Remove hardcoded check from detectShellByCheckingProcessTree for shell types, use already present supportedShell slice. + Add a utility method in strings for verifying a slice contains an element matching the provided predicate. Signed-off-by: Rohan Kumar <[email protected]>
1 parent 8959c8e commit 9e42792

File tree

6 files changed

+189
-109
lines changed

6 files changed

+189
-109
lines changed

pkg/os/shell/shell.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"os"
77
"strings"
88

9+
crcstrings "github.com/crc-org/crc/v2/pkg/strings"
10+
911
"github.com/shirou/gopsutil/v4/process"
1012
"github.com/spf13/cast"
1113

@@ -208,7 +210,9 @@ func detectShellByCheckingProcessTree(p AbstractProcess) string {
208210
if err != nil {
209211
return ""
210212
}
211-
if processName == "zsh" || processName == "bash" || processName == "fish" {
213+
if crcstrings.IsPresentInListSatisfying(supportedShell, processName, func(listElem string, toMatch string) bool {
214+
return strings.HasPrefix(toMatch, listElem)
215+
}) {
212216
return processName
213217
}
214218
p, err = p.Parent()

pkg/os/shell/shell_unix.go

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import (
88
"path/filepath"
99
)
1010

11+
var (
12+
supportedShell = []string{"bash", "zsh", "fish"}
13+
)
14+
1115
// detect detects user's current shell.
1216
func detect() (string, error) {
1317
detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier())

pkg/os/shell/shell_windows.go

+8-77
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,19 @@
11
package shell
22

33
import (
4-
"fmt"
5-
"math"
6-
"os"
7-
"path/filepath"
84
"sort"
95
"strconv"
106
"strings"
11-
"syscall"
12-
"unsafe"
7+
8+
crcstrings "github.com/crc-org/crc/v2/pkg/strings"
139

1410
"github.com/crc-org/crc/v2/pkg/crc/logging"
1511
)
1612

1713
var (
18-
supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"}
14+
supportedShell = []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}
1915
)
2016

21-
// re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go
22-
func getProcessEntry(pid uint32) (pe *syscall.ProcessEntry32, err error) {
23-
snapshot, err := syscall.CreateToolhelp32Snapshot(syscall.TH32CS_SNAPPROCESS, 0)
24-
if err != nil {
25-
return nil, err
26-
}
27-
defer func() {
28-
_ = syscall.CloseHandle(syscall.Handle(snapshot))
29-
}()
30-
31-
var processEntry syscall.ProcessEntry32
32-
processEntry.Size = uint32(unsafe.Sizeof(processEntry))
33-
err = syscall.Process32First(snapshot, &processEntry)
34-
if err != nil {
35-
return nil, err
36-
}
37-
38-
for {
39-
if processEntry.ProcessID == pid {
40-
pe = &processEntry
41-
return
42-
}
43-
44-
err = syscall.Process32Next(snapshot, &processEntry)
45-
if err != nil {
46-
return nil, err
47-
}
48-
}
49-
}
50-
51-
// getNameAndItsPpid returns the exe file name its parent process id.
52-
func getNameAndItsPpid(pid uint32) (exefile string, parentid uint32, err error) {
53-
pe, err := getProcessEntry(pid)
54-
if err != nil {
55-
return "", 0, err
56-
}
57-
58-
name := syscall.UTF16ToString(pe.ExeFile[:])
59-
return name, pe.ParentProcessID, nil
60-
}
61-
6217
func shellType(shell string, defaultShell string) string {
6318
switch {
6419
case strings.Contains(strings.ToLower(shell), "powershell"):
@@ -69,39 +24,15 @@ func shellType(shell string, defaultShell string) string {
6924
return "cmd"
7025
case strings.Contains(strings.ToLower(shell), "wsl"):
7126
return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid=,comm="})
72-
case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"):
27+
case strings.Contains(strings.ToLower(shell), "bash"):
7328
return "bash"
7429
default:
7530
return defaultShell
7631
}
7732
}
7833

7934
func detect() (string, error) {
80-
shell := os.Getenv("SHELL")
81-
82-
if shell == "" {
83-
pid := os.Getppid()
84-
if pid < 0 || pid > math.MaxUint32 {
85-
return "", fmt.Errorf("integer overflow for pid: %v", pid)
86-
}
87-
shell, shellppid, err := getNameAndItsPpid(uint32(pid))
88-
if err != nil {
89-
return "cmd", err // defaulting to cmd
90-
}
91-
shell = shellType(shell, "")
92-
if shell == "" {
93-
shell, _, err := getNameAndItsPpid(shellppid)
94-
if err != nil {
95-
return "cmd", err // defaulting to cmd
96-
}
97-
return shellType(shell, "cmd"), nil
98-
}
99-
return shell, nil
100-
}
101-
102-
if os.Getenv("__fish_bin_dir") != "" {
103-
return "fish", nil
104-
}
35+
shell := detectShellByCheckingProcessTree(currentProcessSupplier())
10536

10637
return shellType(shell, "cmd"), nil
10738
}
@@ -163,9 +94,9 @@ func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string {
16394
lines := strings.Split(psCommandOutput, "\n")
16495
for _, line := range lines {
16596
lineParts := strings.Split(strings.TrimSpace(line), " ")
166-
if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") ||
167-
strings.Contains(lineParts[1], "bash") ||
168-
strings.Contains(lineParts[1], "fish")) {
97+
if len(lineParts) == 2 && crcstrings.IsPresentInListSatisfying(supportedShell, lineParts[1], func(listElem string, toMatch string) bool {
98+
return strings.HasPrefix(toMatch, listElem)
99+
}) {
169100
parsedProcessID, err := strconv.Atoi(lineParts[0])
170101
if err == nil {
171102
processOutputs = append(processOutputs, ProcessOutput{

pkg/os/shell/shell_windows_test.go

+116-30
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,137 @@
11
package shell
22

33
import (
4-
"math"
5-
"os"
64
"testing"
75

86
"github.com/stretchr/testify/assert"
97
)
108

11-
func TestDetect(t *testing.T) {
12-
defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL"))
13-
os.Setenv("SHELL", "")
14-
15-
shell, err := detect()
9+
func TestDetect_WhenUnknownShell_ThenDefaultToCmdShell(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
processTree []MockedProcess
13+
expectedShellType string
14+
}{
15+
{
16+
"failure to get process details for given pid",
17+
[]MockedProcess{},
18+
"",
19+
},
20+
{
21+
"failure while getting name of process",
22+
[]MockedProcess{
23+
{
24+
name: "crc.exe",
25+
},
26+
{
27+
nameGetFails: true,
28+
},
29+
},
30+
"",
31+
},
32+
{
33+
"failure while getting ppid of process",
34+
[]MockedProcess{
35+
{
36+
name: "crc.exe",
37+
},
38+
{
39+
parentGetFails: true,
40+
},
41+
},
42+
"",
43+
},
44+
{
45+
"failure when no shell process in process tree",
46+
[]MockedProcess{
47+
{
48+
name: "crc.exe",
49+
},
50+
{
51+
name: "unknown.exe",
52+
},
53+
},
54+
"",
55+
},
56+
}
57+
for _, tt := range tests {
58+
t.Run(tt.name, func(t *testing.T) {
59+
// Given
60+
currentProcessSupplier = func() AbstractProcess {
61+
return createNewMockProcessTreeFrom(tt.processTree)
62+
}
1663

17-
assert.Contains(t, supportedShell, shell)
18-
assert.NoError(t, err)
19-
}
64+
// When
65+
shell, err := detect()
2066

21-
func TestGetNameAndItsPpidOfCurrent(t *testing.T) {
22-
pid := os.Getpid()
23-
if pid < 0 || pid > math.MaxUint32 {
24-
assert.Fail(t, "integer overflow detected")
25-
}
26-
shell, shellppid, err := getNameAndItsPpid(uint32(pid))
27-
assert.Equal(t, "shell.test.exe", shell)
28-
ppid := os.Getppid()
29-
if ppid < 0 || ppid > math.MaxUint32 {
30-
assert.Fail(t, "integer overflow detected")
67+
// Then
68+
assert.NoError(t, err)
69+
assert.Equal(t, "cmd", shell)
70+
})
3171
}
32-
assert.Equal(t, uint32(ppid), shellppid)
33-
assert.NoError(t, err)
3472
}
3573

36-
func TestGetNameAndItsPpidOfParent(t *testing.T) {
37-
pid := os.Getppid()
38-
if pid < 0 || pid > math.MaxUint32 {
39-
assert.Fail(t, "integer overflow detected")
74+
func TestDetect_GivenProcessTree_ThenReturnShellProcessWithCorrespondingParentPID(t *testing.T) {
75+
tests := []struct {
76+
name string
77+
processTree []MockedProcess
78+
expectedShellType string
79+
}{
80+
{
81+
"bash shell, then detect bash shell",
82+
[]MockedProcess{
83+
{
84+
name: "crc.exe",
85+
},
86+
{
87+
name: "bash.exe",
88+
},
89+
},
90+
"bash",
91+
},
92+
{
93+
"powershell, then detect powershell",
94+
[]MockedProcess{
95+
{
96+
name: "crc.exe",
97+
},
98+
{
99+
name: "powershell.exe",
100+
},
101+
},
102+
"powershell",
103+
},
104+
{
105+
"cmd shell, then detect fish shell",
106+
[]MockedProcess{
107+
{
108+
name: "crc.exe",
109+
},
110+
{
111+
name: "cmd.exe",
112+
},
113+
},
114+
"cmd",
115+
},
40116
}
41-
shell, _, err := getNameAndItsPpid(uint32(pid))
117+
for _, tt := range tests {
118+
t.Run(tt.name, func(t *testing.T) {
119+
// Given
120+
currentProcessSupplier = func() AbstractProcess {
121+
return createNewMockProcessTreeFrom(tt.processTree)
122+
}
123+
// When
124+
shell, err := detect()
42125

43-
assert.Equal(t, "go.exe", shell)
44-
assert.NoError(t, err)
126+
// Then
127+
assert.Equal(t, tt.expectedShellType, shell)
128+
assert.NoError(t, err)
129+
})
130+
}
45131
}
46132

47133
func TestSupportedShells(t *testing.T) {
48-
assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell)
134+
assert.Equal(t, []string{"cmd", "powershell", "wsl", "bash", "zsh", "fish"}, supportedShell)
49135
}
50136

51137
func TestShellType(t *testing.T) {

pkg/strings/strings.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@ import (
66
)
77

88
func Contains(input []string, match string) bool {
9+
return IsPresentInListSatisfying(input, match, func(listElement string, toMatch string) bool {
10+
return listElement == toMatch
11+
})
12+
}
13+
14+
func IsPresentInListSatisfying(input []string, toMatch string, matchingPredicate func(string, string) bool) bool {
915
for _, v := range input {
10-
if v == match {
16+
if matchingPredicate(v, toMatch) {
1117
return true
1218
}
1319
}

0 commit comments

Comments
 (0)