mirror of
https://github.com/sst/opencode.git
synced 2025-08-07 23:08:11 +00:00

Added nil check in GetPersistentShell before accessing shellInstance.isAlive to prevent panic when newPersistentShell returns nil due to shell startup errors. This resolves the "invalid memory address or nil pointer dereference" error that was occurring in the shell tool.
304 lines
6 KiB
Go
304 lines
6 KiB
Go
package shell
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
type PersistentShell struct {
|
|
cmd *exec.Cmd
|
|
stdin *os.File
|
|
isAlive bool
|
|
cwd string
|
|
mu sync.Mutex
|
|
commandQueue chan *commandExecution
|
|
}
|
|
|
|
type commandExecution struct {
|
|
command string
|
|
timeout time.Duration
|
|
resultChan chan commandResult
|
|
ctx context.Context
|
|
}
|
|
|
|
// commandResult is what execCommand reports for one command.
type commandResult struct {
	// stdout and stderr hold the captured output streams.
	stdout string
	stderr string
	// exitCode is the status parsed from the shell, or a synthetic value when
	// the command could not be run or was interrupted.
	exitCode int
	// interrupted is true when the command was stopped by a timeout or by
	// context cancellation.
	interrupted bool
	// err is set when the command could not be handed to the shell.
	err error
}
|
|
|
|
var (
|
|
shellInstance *PersistentShell
|
|
shellInstanceOnce sync.Once
|
|
)
|
|
|
|
func GetPersistentShell(workingDir string) *PersistentShell {
|
|
shellInstanceOnce.Do(func() {
|
|
shellInstance = newPersistentShell(workingDir)
|
|
})
|
|
|
|
if shellInstance == nil || !shellInstance.isAlive {
|
|
shellInstance = newPersistentShell(shellInstance.cwd)
|
|
}
|
|
|
|
return shellInstance
|
|
}
|
|
|
|
func newPersistentShell(cwd string) *PersistentShell {
|
|
shellPath := os.Getenv("SHELL")
|
|
if shellPath == "" {
|
|
shellPath = "/bin/bash"
|
|
}
|
|
|
|
cmd := exec.Command(shellPath, "-l")
|
|
cmd.Dir = cwd
|
|
|
|
stdinPipe, err := cmd.StdinPipe()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
|
|
|
|
err = cmd.Start()
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
shell := &PersistentShell{
|
|
cmd: cmd,
|
|
stdin: stdinPipe.(*os.File),
|
|
isAlive: true,
|
|
cwd: cwd,
|
|
commandQueue: make(chan *commandExecution, 10),
|
|
}
|
|
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
|
|
shell.isAlive = false
|
|
close(shell.commandQueue)
|
|
}
|
|
}()
|
|
shell.processCommands()
|
|
}()
|
|
|
|
go func() {
|
|
err := cmd.Wait()
|
|
if err != nil {
|
|
// Log the error if needed
|
|
}
|
|
shell.isAlive = false
|
|
close(shell.commandQueue)
|
|
}()
|
|
|
|
return shell
|
|
}
|
|
|
|
func (s *PersistentShell) processCommands() {
|
|
for cmd := range s.commandQueue {
|
|
result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
|
|
cmd.resultChan <- result
|
|
}
|
|
}
|
|
|
|
func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if !s.isAlive {
|
|
return commandResult{
|
|
stderr: "Shell is not alive",
|
|
exitCode: 1,
|
|
err: errors.New("shell is not alive"),
|
|
}
|
|
}
|
|
|
|
tempDir := os.TempDir()
|
|
stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
|
|
stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
|
|
statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
|
|
cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
|
|
|
|
defer func() {
|
|
os.Remove(stdoutFile)
|
|
os.Remove(stderrFile)
|
|
os.Remove(statusFile)
|
|
os.Remove(cwdFile)
|
|
}()
|
|
|
|
fullCommand := fmt.Sprintf(`
|
|
eval %s < /dev/null > %s 2> %s
|
|
EXEC_EXIT_CODE=$?
|
|
pwd > %s
|
|
echo $EXEC_EXIT_CODE > %s
|
|
`,
|
|
shellQuote(command),
|
|
shellQuote(stdoutFile),
|
|
shellQuote(stderrFile),
|
|
shellQuote(cwdFile),
|
|
shellQuote(statusFile),
|
|
)
|
|
|
|
_, err := s.stdin.Write([]byte(fullCommand + "\n"))
|
|
if err != nil {
|
|
return commandResult{
|
|
stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
|
|
exitCode: 1,
|
|
err: err,
|
|
}
|
|
}
|
|
|
|
interrupted := false
|
|
|
|
startTime := time.Now()
|
|
|
|
done := make(chan bool)
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
s.killChildren()
|
|
interrupted = true
|
|
done <- true
|
|
return
|
|
|
|
case <-time.After(10 * time.Millisecond):
|
|
if fileExists(statusFile) && fileSize(statusFile) > 0 {
|
|
done <- true
|
|
return
|
|
}
|
|
|
|
if timeout > 0 {
|
|
elapsed := time.Since(startTime)
|
|
if elapsed > timeout {
|
|
s.killChildren()
|
|
interrupted = true
|
|
done <- true
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
<-done
|
|
|
|
stdout := readFileOrEmpty(stdoutFile)
|
|
stderr := readFileOrEmpty(stderrFile)
|
|
exitCodeStr := readFileOrEmpty(statusFile)
|
|
newCwd := readFileOrEmpty(cwdFile)
|
|
|
|
exitCode := 0
|
|
if exitCodeStr != "" {
|
|
fmt.Sscanf(exitCodeStr, "%d", &exitCode)
|
|
} else if interrupted {
|
|
exitCode = 143
|
|
stderr += "\nCommand execution timed out or was interrupted"
|
|
}
|
|
|
|
if newCwd != "" {
|
|
s.cwd = strings.TrimSpace(newCwd)
|
|
}
|
|
|
|
return commandResult{
|
|
stdout: stdout,
|
|
stderr: stderr,
|
|
exitCode: exitCode,
|
|
interrupted: interrupted,
|
|
}
|
|
}
|
|
|
|
func (s *PersistentShell) killChildren() {
|
|
if s.cmd == nil || s.cmd.Process == nil {
|
|
return
|
|
}
|
|
|
|
pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
|
|
output, err := pgrepCmd.Output()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
for pidStr := range strings.SplitSeq(string(output), "\n") {
|
|
if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
|
|
var pid int
|
|
fmt.Sscanf(pidStr, "%d", &pid)
|
|
if pid > 0 {
|
|
proc, err := os.FindProcess(pid)
|
|
if err == nil {
|
|
proc.Signal(syscall.SIGTERM)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
|
|
if !s.isAlive {
|
|
return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
|
|
}
|
|
|
|
timeout := time.Duration(timeoutMs) * time.Millisecond
|
|
|
|
resultChan := make(chan commandResult)
|
|
s.commandQueue <- &commandExecution{
|
|
command: command,
|
|
timeout: timeout,
|
|
resultChan: resultChan,
|
|
ctx: ctx,
|
|
}
|
|
|
|
result := <-resultChan
|
|
return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
|
|
}
|
|
|
|
func (s *PersistentShell) Close() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if !s.isAlive {
|
|
return
|
|
}
|
|
|
|
s.stdin.Write([]byte("exit\n"))
|
|
|
|
s.cmd.Process.Kill()
|
|
s.isAlive = false
|
|
}
|
|
|
|
// shellQuote wraps s in single quotes for use as one shell word. Each embedded
// single quote is rewritten as '\'' (close the quote, emit an escaped quote,
// reopen the quote).
func shellQuote(s string) string {
	escaped := strings.ReplaceAll(s, "'", `'\''`)
	return "'" + escaped + "'"
}
|
|
|
|
// readFileOrEmpty returns the contents of the file at path as a string, or the
// empty string if the file cannot be read.
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
|
|
|
|
// fileExists reports whether os.Stat succeeds for path. Any stat error,
// not only "does not exist", yields false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
|
|
|
|
// fileSize returns the size in bytes reported by os.Stat for path, or 0 if the
// path cannot be stat'ed.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}
|