Skip to content

Commit 9f7677b

Browse files
committed
refactor (shell) : detect shell on windows using gopsutil too (#4588)
As a follow up to #4572, updating shell detection logic on windows to also rely on gopsutil library to detect the shell name by inspecting parent processes of `crc.exe` process. Signed-off-by: Rohan Kumar <[email protected]>
1 parent 6d72284 commit 9f7677b

10 files changed

+257
-221
lines changed

pkg/os/shell/shell.go

+67
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
package shell
22

33
import (
4+
"errors"
45
"fmt"
56
"os"
67
"strings"
78

9+
"github.com/shirou/gopsutil/v4/process"
10+
"github.com/spf13/cast"
11+
812
crcos "github.com/crc-org/crc/v2/pkg/os"
913
)
1014

1115
var (
1216
CommandRunner = crcos.NewLocalCommandRunner()
1317
WindowsSubsystemLinuxKernelMetadataFile = "/proc/version"
18+
ErrUnknownShell = errors.New("Error: Unknown shell")
19+
currentProcessSupplier = createCurrentProcess
1420
)
1521

1622
type Config struct {
@@ -20,6 +26,20 @@ type Config struct {
2026
PathSuffix string
2127
}
2228

29+
// AbstractProcess is an interface created to abstract operations of the gopsutil library
30+
// It is created so that we can override the behavior while writing unit tests by providing
31+
// a mock implementation.
32+
type AbstractProcess interface {
33+
Name() (string, error)
34+
Parent() (AbstractProcess, error)
35+
}
36+
37+
// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's
38+
// process.Process object. This implementation is used in production code.
39+
type RealProcess struct {
40+
*process.Process
41+
}
42+
2343
func GetShell(userShell string) (string, error) {
2444
if userShell != "" {
2545
if !isSupportedShell(userShell) {
@@ -151,3 +171,50 @@ func IsWindowsSubsystemLinux() bool {
151171
}
152172
return false
153173
}
174+
175+
func (p *RealProcess) Parent() (AbstractProcess, error) {
176+
parentProcess, err := p.Process.Parent()
177+
if err != nil {
178+
return nil, err
179+
}
180+
return &RealProcess{parentProcess}, nil
181+
}
182+
183+
func createCurrentProcess() AbstractProcess {
184+
currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid()))
185+
if err != nil {
186+
return nil
187+
}
188+
return &RealProcess{currentProcess}
189+
}
190+
191+
// detectShellByCheckingProcessTree attempts to identify the shell being used by
192+
// examining the process tree starting from the given process ID. This function
193+
// traverses up to ProcessDepthLimit levels up the process hierarchy.
194+
// Parameters:
195+
// - pid (int): The process ID to start checking from.
196+
//
197+
// Returns:
198+
// - string: The name of the shell if found (e.g., "zsh", "bash", "fish");
199+
// otherwise, an empty string is returned if no matching shell is detected
200+
// or an error occurs during the process tree traversal.
201+
//
202+
// Examples:
203+
//
204+
// shellName := detectShellByCheckingProcessTree(1234)
205+
func detectShellByCheckingProcessTree(p AbstractProcess) string {
206+
for p != nil {
207+
processName, err := p.Name()
208+
if err != nil {
209+
return ""
210+
}
211+
if crcos.IsPresentInList(supportedShell, processName) {
212+
return processName
213+
}
214+
p, err = p.Parent()
215+
if err != nil {
216+
return ""
217+
}
218+
}
219+
return ""
220+
}

pkg/os/shell/shell_darwin.go

-5
This file was deleted.

pkg/os/shell/shell_linux.go

-5
This file was deleted.

pkg/os/shell/shell_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package shell
22

33
import (
4+
"errors"
45
"os"
56
"path/filepath"
67
"testing"
@@ -47,6 +48,28 @@ func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (strin
4748
return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn
4849
}
4950

51+
// MockedProcess is a mock implementation of AbstractProcess for testing purposes.
52+
type MockedProcess struct {
53+
name string
54+
parent *MockedProcess
55+
nameGetFails bool
56+
parentGetFails bool
57+
}
58+
59+
func (m MockedProcess) Parent() (AbstractProcess, error) {
60+
if m.parentGetFails || m.parent == nil {
61+
return nil, errors.New("failed to get the pid")
62+
}
63+
return m.parent, nil
64+
}
65+
66+
func (m MockedProcess) Name() (string, error) {
67+
if m.nameGetFails {
68+
return "", errors.New("failed to get the name")
69+
}
70+
return m.name, nil
71+
}
72+
5073
func TestGetPathEnvString(t *testing.T) {
5174
tests := []struct {
5275
name string
@@ -179,3 +202,16 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) {
179202
assert.Equal(t, "wsl", mockCommandExecutor.commandName)
180203
assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs)
181204
}
205+
206+
func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess {
207+
if len(processes) == 0 {
208+
return nil
209+
}
210+
head := &processes[0]
211+
current := head
212+
for i := 1; i < len(processes); i++ {
213+
current.parent = &processes[i]
214+
current = current.parent
215+
}
216+
return head
217+
}

pkg/os/shell/shell_unix.go

+1-68
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,14 @@
44
package shell
55

66
import (
7-
"errors"
87
"fmt"
9-
"os"
108
"path/filepath"
11-
12-
"github.com/shirou/gopsutil/v4/process"
13-
"github.com/spf13/cast"
149
)
1510

1611
var (
17-
ErrUnknownShell = errors.New("Error: Unknown shell")
18-
currentProcessSupplier = createCurrentProcess
12+
supportedShell = []string{"bash", "zsh", "fish"}
1913
)
2014

21-
// AbstractProcess is an interface created to abstract operations of the gopsutil library
22-
// It is created so that we can override the behavior while writing unit tests by providing
23-
// a mock implementation.
24-
type AbstractProcess interface {
25-
Name() (string, error)
26-
Parent() (AbstractProcess, error)
27-
}
28-
29-
// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's
30-
// process.Process object. This implementation is used in production code.
31-
type RealProcess struct {
32-
*process.Process
33-
}
34-
35-
func (p *RealProcess) Parent() (AbstractProcess, error) {
36-
parentProcess, err := p.Process.Parent()
37-
if err != nil {
38-
return nil, err
39-
}
40-
return &RealProcess{parentProcess}, nil
41-
}
42-
43-
func createCurrentProcess() AbstractProcess {
44-
currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid()))
45-
if err != nil {
46-
return nil
47-
}
48-
return &RealProcess{currentProcess}
49-
}
50-
5115
// detect detects user's current shell.
5216
func detect() (string, error) {
5317
detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier())
@@ -58,34 +22,3 @@ func detect() (string, error) {
5822

5923
return filepath.Base(detectedShell), nil
6024
}
61-
62-
// detectShellByCheckingProcessTree attempts to identify the shell being used by
63-
// examining the process tree starting from the given process ID. This function
64-
// traverses up to ProcessDepthLimit levels up the process hierarchy.
65-
// Parameters:
66-
// - pid (int): The process ID to start checking from.
67-
//
68-
// Returns:
69-
// - string: The name of the shell if found (e.g., "zsh", "bash", "fish");
70-
// otherwise, an empty string is returned if no matching shell is detected
71-
// or an error occurs during the process tree traversal.
72-
//
73-
// Examples:
74-
//
75-
// shellName := detectShellByCheckingProcessTree(1234)
76-
func detectShellByCheckingProcessTree(p AbstractProcess) string {
77-
for p != nil {
78-
processName, err := p.Name()
79-
if err != nil {
80-
return ""
81-
}
82-
if processName == "zsh" || processName == "bash" || processName == "fish" {
83-
return processName
84-
}
85-
p, err = p.Parent()
86-
if err != nil {
87-
return ""
88-
}
89-
}
90-
return ""
91-
}

pkg/os/shell/shell_unix_test.go

-36
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,12 @@ package shell
55

66
import (
77
"bytes"
8-
"errors"
98
"os"
109
"testing"
1110

1211
"github.com/stretchr/testify/assert"
1312
)
1413

15-
// MockedProcess is a mock implementation of AbstractProcess for testing purposes.
16-
type MockedProcess struct {
17-
name string
18-
parent *MockedProcess
19-
nameGetFails bool
20-
parentGetFails bool
21-
}
22-
23-
func (m MockedProcess) Parent() (AbstractProcess, error) {
24-
if m.parentGetFails || m.parent == nil {
25-
return nil, errors.New("failed to get the pid")
26-
}
27-
return m.parent, nil
28-
}
29-
30-
func (m MockedProcess) Name() (string, error) {
31-
if m.nameGetFails {
32-
return "", errors.New("failed to get the name")
33-
}
34-
return m.name, nil
35-
}
36-
3714
func TestUnknownShell(t *testing.T) {
3815
tests := []struct {
3916
name string
@@ -183,16 +160,3 @@ func TestGetCurrentProcess(t *testing.T) {
183160
assert.NoError(t, err)
184161
assert.Greater(t, len(currentProcessName), 0)
185162
}
186-
187-
func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess {
188-
if len(processes) == 0 {
189-
return nil
190-
}
191-
head := &processes[0]
192-
current := head
193-
for i := 1; i < len(processes); i++ {
194-
current.parent = &processes[i]
195-
current = current.parent
196-
}
197-
return head
198-
}

0 commit comments

Comments
 (0)