Skip to content

Commit

Permalink
refactor (shell) : detect shell on windows using gopsutil too (#4588)
Browse files Browse the repository at this point in the history
As a follow up to #4572, updating shell detection logic on windows to also
rely on gopsutil library to detect the shell name by inspecting parent processes
of `crc.exe` process.

Signed-off-by: Rohan Kumar <[email protected]>
  • Loading branch information
rohanKanojia committed Feb 3, 2025
1 parent 6d72284 commit 9f7677b
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 221 deletions.
67 changes: 67 additions & 0 deletions pkg/os/shell/shell.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
package shell

import (
"errors"
"fmt"
"os"
"strings"

"github.com/shirou/gopsutil/v4/process"
"github.com/spf13/cast"

crcos "github.com/crc-org/crc/v2/pkg/os"
)

var (
CommandRunner = crcos.NewLocalCommandRunner()
WindowsSubsystemLinuxKernelMetadataFile = "/proc/version"
ErrUnknownShell = errors.New("Error: Unknown shell")
currentProcessSupplier = createCurrentProcess
)

type Config struct {
Expand All @@ -20,6 +26,20 @@ type Config struct {
PathSuffix string
}

// AbstractProcess is an interface created to abstract operations of the gopsutil library
// It is created so that we can override the behavior while writing unit tests by providing
// a mock implementation.
type AbstractProcess interface {
Name() (string, error)
Parent() (AbstractProcess, error)
}

// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's
// process.Process object. This implementation is used in production code.
type RealProcess struct {
*process.Process
}

func GetShell(userShell string) (string, error) {
if userShell != "" {
if !isSupportedShell(userShell) {
Expand Down Expand Up @@ -151,3 +171,50 @@ func IsWindowsSubsystemLinux() bool {
}
return false
}

func (p *RealProcess) Parent() (AbstractProcess, error) {
parentProcess, err := p.Process.Parent()
if err != nil {
return nil, err
}
return &RealProcess{parentProcess}, nil
}

func createCurrentProcess() AbstractProcess {
currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid()))
if err != nil {
return nil
}
return &RealProcess{currentProcess}
}

// detectShellByCheckingProcessTree attempts to identify the shell being used by
// examining the process tree starting from the given process ID. This function
// traverses up to ProcessDepthLimit levels up the process hierarchy.
// Parameters:
// - pid (int): The process ID to start checking from.
//
// Returns:
// - string: The name of the shell if found (e.g., "zsh", "bash", "fish");
// otherwise, an empty string is returned if no matching shell is detected
// or an error occurs during the process tree traversal.
//
// Examples:
//
// shellName := detectShellByCheckingProcessTree(1234)
func detectShellByCheckingProcessTree(p AbstractProcess) string {
for p != nil {
processName, err := p.Name()
if err != nil {
return ""
}
if crcos.IsPresentInList(supportedShell, processName) {
return processName
}
p, err = p.Parent()
if err != nil {
return ""
}
}
return ""
}
5 changes: 0 additions & 5 deletions pkg/os/shell/shell_darwin.go

This file was deleted.

5 changes: 0 additions & 5 deletions pkg/os/shell/shell_linux.go

This file was deleted.

36 changes: 36 additions & 0 deletions pkg/os/shell/shell_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package shell

import (
"errors"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -47,6 +48,28 @@ func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (strin
return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn
}

// MockedProcess is a mock implementation of AbstractProcess for testing purposes.
type MockedProcess struct {
name string
parent *MockedProcess
nameGetFails bool
parentGetFails bool
}

func (m MockedProcess) Parent() (AbstractProcess, error) {
if m.parentGetFails || m.parent == nil {
return nil, errors.New("failed to get the pid")
}
return m.parent, nil
}

func (m MockedProcess) Name() (string, error) {
if m.nameGetFails {
return "", errors.New("failed to get the name")
}
return m.name, nil
}

func TestGetPathEnvString(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -179,3 +202,16 @@ func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) {
assert.Equal(t, "wsl", mockCommandExecutor.commandName)
assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs)
}

func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess {
if len(processes) == 0 {
return nil
}
head := &processes[0]
current := head
for i := 1; i < len(processes); i++ {
current.parent = &processes[i]
current = current.parent
}
return head
}
69 changes: 1 addition & 68 deletions pkg/os/shell/shell_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,14 @@
package shell

import (
"errors"
"fmt"
"os"
"path/filepath"

"github.com/shirou/gopsutil/v4/process"
"github.com/spf13/cast"
)

var (
ErrUnknownShell = errors.New("Error: Unknown shell")
currentProcessSupplier = createCurrentProcess
supportedShell = []string{"bash", "zsh", "fish"}
)

// AbstractProcess is an interface created to abstract operations of the gopsutil library
// It is created so that we can override the behavior while writing unit tests by providing
// a mock implementation.
type AbstractProcess interface {
Name() (string, error)
Parent() (AbstractProcess, error)
}

// RealProcess is a wrapper implementation of AbstractProcess to wrap around the gopsutil library's
// process.Process object. This implementation is used in production code.
type RealProcess struct {
*process.Process
}

func (p *RealProcess) Parent() (AbstractProcess, error) {
parentProcess, err := p.Process.Parent()
if err != nil {
return nil, err
}
return &RealProcess{parentProcess}, nil
}

func createCurrentProcess() AbstractProcess {
currentProcess, err := process.NewProcess(cast.ToInt32(os.Getpid()))
if err != nil {
return nil
}
return &RealProcess{currentProcess}
}

// detect detects user's current shell.
func detect() (string, error) {
detectedShell := detectShellByCheckingProcessTree(currentProcessSupplier())
Expand All @@ -58,34 +22,3 @@ func detect() (string, error) {

return filepath.Base(detectedShell), nil
}

// detectShellByCheckingProcessTree attempts to identify the shell being used by
// examining the process tree starting from the given process ID. This function
// traverses up to ProcessDepthLimit levels up the process hierarchy.
// Parameters:
// - pid (int): The process ID to start checking from.
//
// Returns:
// - string: The name of the shell if found (e.g., "zsh", "bash", "fish");
// otherwise, an empty string is returned if no matching shell is detected
// or an error occurs during the process tree traversal.
//
// Examples:
//
// shellName := detectShellByCheckingProcessTree(1234)
func detectShellByCheckingProcessTree(p AbstractProcess) string {
for p != nil {
processName, err := p.Name()
if err != nil {
return ""
}
if processName == "zsh" || processName == "bash" || processName == "fish" {
return processName
}
p, err = p.Parent()
if err != nil {
return ""
}
}
return ""
}
36 changes: 0 additions & 36 deletions pkg/os/shell/shell_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,12 @@ package shell

import (
"bytes"
"errors"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

// MockedProcess is a mock implementation of AbstractProcess for testing purposes.
type MockedProcess struct {
name string
parent *MockedProcess
nameGetFails bool
parentGetFails bool
}

func (m MockedProcess) Parent() (AbstractProcess, error) {
if m.parentGetFails || m.parent == nil {
return nil, errors.New("failed to get the pid")
}
return m.parent, nil
}

func (m MockedProcess) Name() (string, error) {
if m.nameGetFails {
return "", errors.New("failed to get the name")
}
return m.name, nil
}

func TestUnknownShell(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -183,16 +160,3 @@ func TestGetCurrentProcess(t *testing.T) {
assert.NoError(t, err)
assert.Greater(t, len(currentProcessName), 0)
}

func createNewMockProcessTreeFrom(processes []MockedProcess) AbstractProcess {
if len(processes) == 0 {
return nil
}
head := &processes[0]
current := head
for i := 1; i < len(processes); i++ {
current.parent = &processes[i]
current = current.parent
}
return head
}
Loading

0 comments on commit 9f7677b

Please sign in to comment.