diff --git a/_fixtures/testnextprog.go b/_fixtures/testnextprog.go index a72d4c41..be03000f 100644 --- a/_fixtures/testnextprog.go +++ b/_fixtures/testnextprog.go @@ -7,7 +7,7 @@ import ( ) func sleepytime() { - time.Sleep(time.Nanosecond) + time.Sleep(100 * time.Millisecond) } func helloworld() { diff --git a/proctl/proctl_linux_amd64.go b/proctl/proctl_linux_amd64.go index 34d73067..ab73cfbe 100644 --- a/proctl/proctl_linux_amd64.go +++ b/proctl/proctl_linux_amd64.go @@ -3,9 +3,7 @@ package proctl import ( - "bytes" "debug/gosym" - "encoding/binary" "fmt" "os" "os/exec" @@ -40,6 +38,7 @@ type BreakPoint struct { Line int Addr uint64 OriginalData []byte + temp bool } type BreakPointExistsError struct { @@ -184,60 +183,14 @@ func (dbp *DebuggedProcess) addThread(tid int) (*ThreadContext, error) { return dbp.Threads[tid], nil } -// Sets a breakpoint in the running process. +// Sets a breakpoint in the current thread. func (dbp *DebuggedProcess) Break(addr uintptr) (*BreakPoint, error) { - var ( - int3 = []byte{0xCC} - f, l, fn = dbp.GoSymTable.PCToLine(uint64(addr)) - originalData = make([]byte, 1) - ) - - if fn == nil { - return nil, InvalidAddressError{address: addr} - } - - _, err := syscall.PtracePeekData(dbp.CurrentThread.Id, addr, originalData) - if err != nil { - return nil, err - } - - if bytes.Equal(originalData, int3) { - return nil, BreakPointExistsError{f, l, addr} - } - - _, err = syscall.PtracePokeData(dbp.CurrentThread.Id, addr, int3) - if err != nil { - return nil, err - } - - breakpoint := &BreakPoint{ - FunctionName: fn.Name, - File: f, - Line: l, - Addr: uint64(addr), - OriginalData: originalData, - } - - dbp.BreakPoints[uint64(addr)] = breakpoint - - return breakpoint, nil + return dbp.CurrentThread.Break(addr) } -// Clears a breakpoint. +// Clears a breakpoint in the current thread. func (dbp *DebuggedProcess) Clear(pc uint64) (*BreakPoint, error) { - bp, ok := dbp.BreakPoints[pc] - if !ok { - return nil, fmt.Errorf("No breakpoint currently set for %#v", pc) - } - - _, err := syscall.PtracePokeData(dbp.CurrentThread.Id, uintptr(bp.Addr), bp.OriginalData) - if err != nil { - return nil, err - } - - delete(dbp.BreakPoints, pc) - - return bp, nil + return dbp.CurrentThread.Clear(pc) } // Returns the status of the current main thread context. @@ -288,8 +241,10 @@ func (dbp *DebuggedProcess) Step() (err error) { func (dbp *DebuggedProcess) Next() error { for _, thread := range dbp.Threads { err := thread.Next() - if _, ok := err.(ProcessExitedError); !ok { - return err + if err != nil { + if _, ok := err.(ProcessExitedError); !ok { + return err + } } } @@ -305,7 +260,7 @@ func (dbp *DebuggedProcess) Continue() error { } } - _, _, err := wait(dbp, -1, 0) + _, _, err := trapWait(dbp, -1, 0) if err != nil { if _, ok := err.(ProcessExitedError); !ok { return err @@ -402,20 +357,6 @@ func (dbp *DebuggedProcess) obtainGoSymbols(wg *sync.WaitGroup) { dbp.GoSymTable = tab } -// Takes an offset from RSP and returns the address of the -// instruction the currect function is going to return to. -func (dbp *DebuggedProcess) ReturnAddressFromOffset(offset int64) uint64 { - regs, err := dbp.Registers() - if err != nil { - panic("Could not obtain register values") - } - - retaddr := int64(regs.Rsp) + offset - data := make([]byte, 8) - syscall.PtracePeekText(dbp.Pid, uintptr(retaddr), data) - return binary.LittleEndian.Uint64(data) -} - type ProcessExitedError struct { pid int } @@ -424,110 +365,82 @@ func (pe ProcessExitedError) Error() string { return fmt.Sprintf("process %d has exited", pe.pid) } -func wait(dbp *DebuggedProcess, p int, options int) (int, *syscall.WaitStatus, error) { +func trapWait(dbp *DebuggedProcess, p int, options int) (int, *syscall.WaitStatus, error) { + var status syscall.WaitStatus + for { - pid, status, err := timeoutWait(p, options) + pid, err := syscall.Wait4(p, &status, syscall.WALL|options, nil) if err != nil { - if _, ok := err.(TimeoutError); ok { - return p, nil, nil - } return -1, nil, fmt.Errorf("wait err %s %d", err, pid) } thread, threadtraced := dbp.Threads[pid] - if threadtraced { - thread.Status = status + if !threadtraced { + continue } + thread.Status = &status if status.Exited() { if pid == dbp.Pid { return 0, nil, ProcessExitedError{pid} } - delete(dbp.Threads, pid) + continue } - if status.StopSignal() == syscall.SIGTRAP { - if status.TrapCause() == syscall.PTRACE_EVENT_CLONE { - // A traced thread has cloned a new thread, grab the pid and - // add it to our list of traced threads. - msg, err := syscall.PtraceGetEventMsg(pid) - if err != nil { - return 0, nil, fmt.Errorf("could not get event message: %s", err) - } - - _, err = dbp.addThread(int(msg)) - if err != nil { - if _, ok := err.(ProcessExitedError); ok { - continue - } - return 0, nil, err - } - - err = syscall.PtraceCont(int(msg), 0) - if err != nil { - return 0, nil, fmt.Errorf("could not continue new thread %d %s", msg, err) - } - - err = syscall.PtraceCont(pid, 0) - if err != nil { - return 0, nil, fmt.Errorf("could not continue stopped thread %d %s", pid, err) - } - - continue + switch status.TrapCause() { + case syscall.PTRACE_EVENT_CLONE: + addNewThread(dbp, pid) + default: + pc, err := thread.CurrentPC() + if err != nil { + return -1, nil, fmt.Errorf("could not get current pc %s", err) } - - if pid != dbp.CurrentThread.Id { - fmt.Printf("changed thread context from %d to %d\n", dbp.CurrentThread.Id, pid) - dbp.CurrentThread = thread - } - - pc, _ := thread.CurrentPC() - // Check to see if we have hit a breakpoint - // that we know about. - if _, ok := dbp.BreakPoints[pc-1]; ok { - // Loop through all threads and ensure that we - // stop the rest of them, so that by the time - // we return control to the user, all threads - // are inactive. We send SIGSTOP and ensure all - // threads are in in signal-delivery-stop mode. - for _, th := range dbp.Threads { - if th.Id == pid { - // This thread is already stopped. - continue - } - - ps, err := parseProcessStatus(pid) - if err != nil { - return -1, nil, err - } - - if ps.state == STATUS_TRACE_STOP { - continue - } - - err = syscall.Tgkill(dbp.Pid, th.Id, syscall.SIGALRM) - if err != nil { - return -1, nil, err - } - - pid, err := syscall.Wait4(th.Id, nil, syscall.WALL, nil) - if err != nil { - return -1, nil, fmt.Errorf("wait err %s %d", err, pid) - } + // Check to see if we have hit a breakpoint. + if bp, ok := dbp.BreakPoints[pc-1]; ok { + if !bp.temp { + handleBreakPoint(dbp, thread, pid) } - return pid, status, nil + return pid, &status, nil } } if status.Stopped() { - if pid == dbp.Pid { - return pid, status, nil - } + // The thread has stopped, but has not hit a breakpoint. + // Continue the thread without returning control back + // to the console. + syscall.PtraceCont(pid, 0) } } } +func addNewThread(dbp *DebuggedProcess, pid int) error { + // A traced thread has cloned a new thread, grab the pid and + // add it to our list of traced threads. + msg, err := syscall.PtraceGetEventMsg(pid) + if err != nil { + return fmt.Errorf("could not get event message: %s", err) + } + fmt.Println("new thread spawned", msg) + + _, err = dbp.addThread(int(msg)) + if err != nil { + return err + } + + err = syscall.PtraceCont(int(msg), 0) + if err != nil { + return fmt.Errorf("could not continue new thread %d %s", msg, err) + } + + err = syscall.PtraceCont(pid, 0) + if err != nil { + return fmt.Errorf("could not continue stopped thread %d %s", pid, err) + } + + return nil +} + type waitstats struct { pid int status *syscall.WaitStatus @@ -550,6 +463,17 @@ func timeoutWait(pid int, options int) (int, *syscall.WaitStatus, error) { errchan = make(chan error) ) + if pid > 0 { + ps, err := parseProcessStatus(pid) + if err != nil { + return -1, nil, err + } + + if ps.state == STATUS_SLEEPING { + return 0, nil, nil + } + } + go func(pid int) { pid, err := syscall.Wait4(pid, &status, syscall.WALL|options, nil) if err != nil { @@ -562,16 +486,57 @@ func timeoutWait(pid int, options int) (int, *syscall.WaitStatus, error) { select { case s := <-statchan: return s.pid, s.status, nil - case <-time.After(1 * time.Second): - ps, err := parseProcessStatus(pid) - if err != nil { - return -1, nil, err + case <-time.After(2 * time.Second): + if pid > 0 { + ps, err := parseProcessStatus(pid) + if err != nil { + return -1, nil, err + } + syscall.Tgkill(ps.ppid, ps.pid, syscall.SIGSTOP) } - syscall.Tgkill(ps.ppid, ps.pid, syscall.SIGSTOP) - - return pid, nil, TimeoutError{pid} + return 0, nil, TimeoutError{pid} case err := <-errchan: return -1, nil, err } } + +func handleBreakPoint(dbp *DebuggedProcess, thread *ThreadContext, pid int) error { + if pid != dbp.CurrentThread.Id { + fmt.Printf("thread context changed from %d to %d\n", dbp.CurrentThread.Id, pid) + dbp.CurrentThread = thread + } + + // Loop through all threads and ensure that we + // stop the rest of them, so that by the time + // we return control to the user, all threads + // are inactive. We send SIGSTOP and ensure all + // threads are in in signal-delivery-stop mode. + for _, th := range dbp.Threads { + if th.Id == pid { + // This thread is already stopped. + continue + } + + ps, err := parseProcessStatus(th.Id) + if err != nil { + return err + } + + if ps.state == STATUS_TRACE_STOP { + continue + } + + err = syscall.Tgkill(dbp.Pid, th.Id, syscall.SIGSTOP) + if err != nil { + return err + } + + pid, err := syscall.Wait4(th.Id, nil, syscall.WALL, nil) + if err != nil { + return fmt.Errorf("wait err %s %d", err, pid) + } + } + + return nil +} diff --git a/proctl/proctl_test.go b/proctl/proctl_test.go index 118d194e..822e4b98 100644 --- a/proctl/proctl_test.go +++ b/proctl/proctl_test.go @@ -230,7 +230,7 @@ func TestNext(t *testing.T) { } if len(p.BreakPoints) != 1 { - t.Fatal("Not all breakpoints were cleaned up") + t.Fatal("Not all breakpoints were cleaned up", len(p.BreakPoints)) } }) } diff --git a/proctl/threads_linux_amd64.go b/proctl/threads_linux_amd64.go index 96e5dec2..8cfe9e25 100644 --- a/proctl/threads_linux_amd64.go +++ b/proctl/threads_linux_amd64.go @@ -1,6 +1,8 @@ package proctl import ( + "bytes" + "encoding/binary" "fmt" "os" "strconv" @@ -46,14 +48,98 @@ func (thread *ThreadContext) CurrentPC() (uint64, error) { return regs.PC(), nil } +// Sets a software breakpoint at addr, and stores it in the process wide +// break point table. Setting a break point must be thread specific due to +// ptrace actions needing the thread to be in a signal-delivery-stop in order +// to initiate any ptrace command. Otherwise, it really doesn't matter +// as we're only dealing with threads. +func (thread *ThreadContext) Break(addr uintptr) (*BreakPoint, error) { + var ( + int3 = []byte{0xCC} + f, l, fn = thread.Process.GoSymTable.PCToLine(uint64(addr)) + originalData = make([]byte, 1) + ) + + if fn == nil { + return nil, InvalidAddressError{address: addr} + } + + _, err := syscall.PtracePeekData(thread.Id, addr, originalData) + if err != nil { + fmt.Println("PEEK ERR") + return nil, err + } + + if bytes.Equal(originalData, int3) { + return nil, BreakPointExistsError{f, l, addr} + } + + _, err = syscall.PtracePokeData(thread.Id, addr, int3) + if err != nil { + fmt.Println("POKE ERR") + return nil, err + } + + breakpoint := &BreakPoint{ + FunctionName: fn.Name, + File: f, + Line: l, + Addr: uint64(addr), + OriginalData: originalData, + } + + thread.Process.BreakPoints[uint64(addr)] = breakpoint + + return breakpoint, nil +} + +// Clears a software breakpoint, and removes it from the process level +// break point table. +func (thread *ThreadContext) Clear(pc uint64) (*BreakPoint, error) { + bp, ok := thread.Process.BreakPoints[pc] + if !ok { + return nil, fmt.Errorf("No breakpoint currently set for %#v", pc) + } + + if _, err := syscall.PtracePokeData(thread.Id, uintptr(bp.Addr), bp.OriginalData); err != nil { + ps, err := parseProcessStatus(thread.Id) + if err != nil { + return nil, err + } + + if ps.state != STATUS_TRACE_STOP { + if err := syscall.Tgkill(thread.Process.Pid, thread.Id, syscall.SIGSTOP); err != nil { + return nil, err + } + if _, err := syscall.Wait4(thread.Id, nil, syscall.WALL, nil); err != nil { + return nil, err + } + if _, err := syscall.PtracePokeData(thread.Id, uintptr(bp.Addr), bp.OriginalData); err != nil { + return nil, err + } + } + } + + delete(thread.Process.BreakPoints, pc) + + return bp, nil +} + func (thread *ThreadContext) Continue() error { - // Stepping first will ensure we are able to continue - // past a breakpoint if that's currently where we are stopped. - err := thread.Step() + // Check whether we are stopped at a breakpoint, and + // if so, single step over it before continuing. + regs, err := thread.Registers() if err != nil { return err } + if _, ok := thread.Process.BreakPoints[regs.PC()-1]; ok { + err := thread.Step() + if err != nil { + return err + } + } + return syscall.PtraceCont(thread.Id, 0) } @@ -81,7 +167,7 @@ func (thread *ThreadContext) Step() (err error) { // Restore breakpoint now that we have passed it. defer func() { - _, err = thread.Process.Break(uintptr(bp.Addr)) + _, err = thread.Break(uintptr(bp.Addr)) }() } @@ -90,8 +176,11 @@ func (thread *ThreadContext) Step() (err error) { return fmt.Errorf("step failed: %s", err.Error()) } - _, _, err = wait(thread.Process, thread.Id, 0) + _, _, err = timeoutWait(thread.Id, 0) if err != nil { + if _, ok := err.(TimeoutError); ok { + return nil + } return fmt.Errorf("step failed: %s", err.Error()) } @@ -106,9 +195,7 @@ func (thread *ThreadContext) Next() (err error) { } if _, ok := thread.Process.BreakPoints[pc-1]; ok { - // Decrement the PC to be before - // the breakpoint instruction. - pc-- + pc-- // Decrement PC to account for BreakPoint } _, l, _ := thread.Process.GoSymTable.PCToLine(pc) @@ -126,7 +213,7 @@ func (thread *ThreadContext) Next() (err error) { return thread.CurrentPC() } - ret := thread.Process.ReturnAddressFromOffset(fde.ReturnAddressOffset(pc)) + ret := thread.ReturnAddressFromOffset(fde.ReturnAddressOffset(pc)) for { pc, err = step() if err != nil { @@ -134,13 +221,12 @@ func (thread *ThreadContext) Next() (err error) { } if !fde.Cover(pc) && pc != ret { - thread.continueToReturnAddress(pc, fde) + err := thread.continueToReturnAddress(pc, fde) if err != nil { - if ierr, ok := err.(InvalidAddressError); ok { - return ierr + if _, ok := err.(InvalidAddressError); !ok { + return err } } - pc, _ = thread.CurrentPC() } @@ -160,26 +246,25 @@ func (thread *ThreadContext) continueToReturnAddress(pc uint64, fde *frame.Frame // of this function. Therefore the function // has not had a chance to modify its' stack // and change our offset. - addr := thread.Process.ReturnAddressFromOffset(0) - bp, err := thread.Process.Break(uintptr(addr)) + addr := thread.ReturnAddressFromOffset(0) + bp, err := thread.Break(uintptr(addr)) if err != nil { if _, ok := err.(BreakPointExistsError); !ok { return err } } + bp.temp = true err = thread.Continue() if err != nil { return err } - _, _, err = wait(thread.Process, thread.Id, 0) - if err != nil { + if _, _, err := trapWait(thread.Process, thread.Id, 0); err != nil { return err } - err = thread.clearTempBreakpoint(bp.Addr) - if err != nil { + if err := thread.clearTempBreakpoint(bp.Addr); err != nil { return err } @@ -189,15 +274,30 @@ func (thread *ThreadContext) continueToReturnAddress(pc uint64, fde *frame.Frame return nil } +// Takes an offset from RSP and returns the address of the +// instruction the currect function is going to return to. +func (thread *ThreadContext) ReturnAddressFromOffset(offset int64) uint64 { + regs, err := thread.Registers() + if err != nil { + panic("Could not obtain register values") + } + + retaddr := int64(regs.Rsp) + offset + data := make([]byte, 8) + syscall.PtracePeekText(thread.Id, uintptr(retaddr), data) + return binary.LittleEndian.Uint64(data) +} + func (thread *ThreadContext) clearTempBreakpoint(pc uint64) error { if bp, ok := thread.Process.BreakPoints[pc]; ok { - regs, err := thread.Registers() + _, err := thread.Clear(bp.Addr) if err != nil { + fmt.Println("ERR", err) return err } // Reset program counter to our restored instruction. - bp, err = thread.Process.Clear(bp.Addr) + regs, err := thread.Registers() if err != nil { return err }