diff --git a/_fixtures/parallel_next.go b/_fixtures/parallel_next.go new file mode 100644 index 00000000..70338f5e --- /dev/null +++ b/_fixtures/parallel_next.go @@ -0,0 +1,21 @@ +package main + +import ( + "fmt" + "sync" +) + +func sayhi(n int, wg *sync.WaitGroup) { + fmt.Println("hi", n) + fmt.Println("hi", n) + wg.Done() +} + +func main() { + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go sayhi(i, &wg) + } + wg.Wait() +} diff --git a/proc/proc.go b/proc/proc.go index 5f1e72ae..64eba5de 100644 --- a/proc/proc.go +++ b/proc/proc.go @@ -252,7 +252,7 @@ func (dbp *Process) Next() error { return dbp.run(dbp.next) } -func (dbp *Process) next() error { +func (dbp *Process) next() (err error) { // Make sure we clean up the temp breakpoints created by thread.Next defer dbp.clearTempBreakpoints() @@ -260,63 +260,78 @@ func (dbp *Process) next() error { // blocked trying to read from a channel. This is so that // if control flow switches to that goroutine, we end up // somewhere useful instead of in runtime code. - chanRecvCount, err := dbp.setChanRecvBreakpoints() - if err != nil { + if _, err := dbp.setChanRecvBreakpoints(); err != nil { return err } + // Get the goroutine for the current thread. We will + // use it later in order to ensure we are on the same + // goroutine. g, err := dbp.CurrentThread.GetG() if err != nil { return err } - if g.DeferPC != 0 { - _, err = dbp.SetTempBreakpoint(g.DeferPC) - if err != nil { - return err - } - } - var goroutineExiting bool - var waitCount int - for _, th := range dbp.Threads { - // Ignore threads that aren't running go code. 
- if !th.blocked() { - waitCount++ - if err = th.SetNextBreakpoints(); err != nil { - if gerr, ok := err.(GoroutineExitingError); ok { - waitCount = waitCount - 1 + chanRecvCount - if gerr.goid == g.Id { - goroutineExiting = true - } - } else { - return err - } + threadNext := func(thread *Thread) error { + if err = thread.setNextBreakpoints(); err != nil { + switch t := err.(type) { + case ThreadBlockedError, NoReturnAddr: // Noop + case GoroutineExitingError: + goroutineExiting = t.goid == g.Id + default: + return err } } - if err = th.Continue(); err != nil { + return thread.Continue() + } + + // Make sure that we halt the process at the end of this + // function. We could get into a situation where we have + // started some, but not all threads. + defer func() { err = dbp.Halt() }() + + // Set next breakpoints and then continue each thread. + for _, th := range dbp.Threads { + if err := threadNext(th); err != nil { return err } } - for waitCount > 0 { - thread, err := dbp.trapWait(-1) - if err != nil { + for { + if _, err := dbp.trapWait(-1); err != nil { return err } - tg, err := thread.GetG() - if err != nil { - return err - } - // Make sure we're on the same goroutine, unless it has exited. - if tg.Id == g.Id || goroutineExiting { - if dbp.CurrentThread != thread { - dbp.SwitchThread(thread.Id) + // We need to wait for our goroutine to execute, which may not happen + // immediately. + // + // Loop through all threads, and for each stopped thread + // see if it is the thread that we care about (thread.g == original.g). + // If so, we're done. Otherwise set next temp breakpoints for + // each thread and continue them. The reason we do this is because + // if our goroutine is paused, we must execute other threads in order + // for them to get to a scheduling point, so they can pick up the + // goroutine we care about and begin executing it. 
+ for _, thr := range dbp.Threads { + if !thr.Stopped() { + continue + } + tg, err := thr.GetG() + if err != nil { + return err + } + // Make sure we're on the same goroutine, unless it has exited. + if tg.Id == g.Id || goroutineExiting { + if dbp.CurrentThread != thr { + dbp.SwitchThread(thr.Id) + } + return nil + } + if err := threadNext(thr); err != nil { + return err } } - waitCount-- } - return dbp.Halt() } func (dbp *Process) setChanRecvBreakpoints() (int, error) { diff --git a/proc/proc_linux.go b/proc/proc_linux.go index 10f11d7b..662d1354 100644 --- a/proc/proc_linux.go +++ b/proc/proc_linux.go @@ -67,7 +67,7 @@ func (dbp *Process) Kill() (err error) { if dbp.exited { return nil } - if !stopped(dbp.Pid) { + if !dbp.Threads[dbp.Pid].Stopped() { return errors.New("process must be stopped in order to kill it") } if err = sys.Kill(-dbp.Pid, sys.SIGKILL); err != nil { @@ -322,14 +322,6 @@ func status(pid int) rune { return state } -func stopped(pid int) bool { - state := status(pid) - if state == STATUS_TRACE_STOP { - return true - } - return false -} - func wait(pid, tgid, options int) (int, *sys.WaitStatus, error) { var s sys.WaitStatus if (pid != tgid) || (options != 0) { diff --git a/proc/proc_test.go b/proc/proc_test.go index 1feae022..b9c21376 100644 --- a/proc/proc_test.go +++ b/proc/proc_test.go @@ -130,6 +130,9 @@ func TestHalt(t *testing.T) { // actually stopped, err will not be nil if the process // is still running. 
for _, th := range p.Threads { + if !th.Stopped() { + t.Fatal("expected thread to be stopped, but was not") + } if th.running != false { t.Fatal("expected running = false for thread", th.Id) } @@ -297,6 +300,36 @@ func TestNextGeneral(t *testing.T) { testnext("testnextprog", testcases, "main.testnext", t) } +func TestNextConcurrent(t *testing.T) { + testcases := []nextTest{ + {9, 10}, + {10, 11}, + } + withTestProcess("parallel_next", t, func(p *Process, fixture protest.Fixture) { + _, err := setFunctionBreakpoint(p, "main.sayhi") + assertNoError(err, t, "SetBreakpoint") + assertNoError(p.Continue(), t, "Continue") + f, ln := currentLineNumber(p, t) + initV, err := p.EvalVariable("n") + assertNoError(err, t, "EvalVariable") + for _, tc := range testcases { + if ln != tc.begin { + t.Fatalf("Program not stopped at correct spot expected %d was %s:%d", tc.begin, filepath.Base(f), ln) + } + assertNoError(p.Next(), t, "Next() returned an error") + f, ln = currentLineNumber(p, t) + if ln != tc.end { + t.Fatalf("Program did not continue to correct next location expected %d was %s:%d", tc.end, filepath.Base(f), ln) + } + v, err := p.EvalVariable("n") + assertNoError(err, t, "EvalVariable") + if v.Value != initV.Value { + t.Fatal("Did not end up on same goroutine") + } + } + }) +} + func TestNextGoroutine(t *testing.T) { testcases := []nextTest{ {47, 42}, diff --git a/proc/stack.go b/proc/stack.go index 2959d3f0..4ae99507 100644 --- a/proc/stack.go +++ b/proc/stack.go @@ -5,6 +5,14 @@ import ( "fmt" ) +type NoReturnAddr struct { + fn string +} + +func (nra NoReturnAddr) Error() string { + return fmt.Sprintf("could not find return address for %s", nra.fn) +} + // Takes an offset from RSP and returns the address of the // instruction the current function is going to return to. 
func (thread *Thread) ReturnAddress() (uint64, error) { @@ -13,7 +21,7 @@ func (thread *Thread) ReturnAddress() (uint64, error) { return 0, err } if len(locations) < 2 { - return 0, fmt.Errorf("could not find return address for %s", locations[0].Fn.BaseName()) + return 0, NoReturnAddr{locations[0].Fn.BaseName()} } return locations[1].PC, nil } diff --git a/proc/threads.go b/proc/threads.go index 11982209..9a918fc7 100644 --- a/proc/threads.go +++ b/proc/threads.go @@ -115,6 +115,12 @@ func (thread *Thread) Location() (*Location, error) { return &Location{PC: pc, File: f, Line: l, Fn: fn}, nil } +type ThreadBlockedError struct{} + +func (tbe ThreadBlockedError) Error() string { + return "" +} + // Set breakpoints for potential next lines. // // There are two modes of operation for this method. First, @@ -129,11 +135,23 @@ func (thread *Thread) Location() (*Location, error) { // at every single line within the current function, and // another at the functions return address, in case we're at // the end. -func (thread *Thread) SetNextBreakpoints() (err error) { +func (thread *Thread) setNextBreakpoints() (err error) { + if thread.blocked() { + return ThreadBlockedError{} + } curpc, err := thread.PC() if err != nil { return err } + g, err := thread.GetG() + if err != nil { + return err + } + if g.DeferPC != 0 { + if _, err = thread.dbp.SetTempBreakpoint(g.DeferPC); err != nil { + return err + } + } // Grab info on our current stack frame. Used to determine // whether we may be stepping outside of the current function. @@ -148,15 +166,11 @@ func (thread *Thread) SetNextBreakpoints() (err error) { return err } if filepath.Ext(loc.File) == ".go" { - if err = thread.next(curpc, fde, loc.File, loc.Line); err != nil { - return err - } + err = thread.next(curpc, fde, loc.File, loc.Line) } else { - if err = thread.cnext(curpc, fde); err != nil { - return err - } + err = thread.cnext(curpc, fde) } - return nil + return err } // Go routine is exiting. 
@@ -278,3 +292,10 @@ func (thread *Thread) GetG() (g *G, err error) { } return } + +// Returns whether the thread is stopped at +// the operating system level. Actual implementation +// is OS dependent, look in OS thread file. +func (thread *Thread) Stopped() bool { + return thread.stopped() +} diff --git a/proc/threads_darwin.c b/proc/threads_darwin.c index d3ef14c1..51d0e97a 100644 --- a/proc/threads_darwin.c +++ b/proc/threads_darwin.c @@ -123,3 +123,15 @@ clear_trap_flag(thread_act_t thread) { return thread_set_state(thread, x86_THREAD_STATE64, (thread_state_t)&regs, count); } + +int +thread_blocked(thread_act_t thread) { + kern_return_t kret; + struct thread_basic_info info; + unsigned int info_count = THREAD_BASIC_INFO_COUNT; + + kret = thread_info((thread_t)thread, THREAD_BASIC_INFO, (thread_info_t)&info, &info_count); + if (kret != KERN_SUCCESS) return -1; + + return info.suspend_count; +} diff --git a/proc/threads_darwin.go b/proc/threads_darwin.go index c41c4fa5..e4470262 100644 --- a/proc/threads_darwin.go +++ b/proc/threads_darwin.go @@ -12,14 +12,22 @@ type OSSpecificDetails struct { registers C.x86_thread_state64_t } -func (t *Thread) Halt() error { - var kret C.kern_return_t - kret = C.thread_suspend(t.os.thread_act) - if kret != C.KERN_SUCCESS { - return fmt.Errorf("could not suspend thread %d", t.Id) +func (t *Thread) Halt() (err error) { + defer func() { + if err == nil { + t.running = false + } + }() + if t.Stopped() { + return } - t.running = false - return nil + kret := C.thread_suspend(t.os.thread_act) + if kret != C.KERN_SUCCESS { + errStr := C.GoString(C.mach_error_string(C.mach_error_t(kret))) + err = fmt.Errorf("could not suspend thread %d %s", t.Id, errStr) + return + } + return } func (t *Thread) singleStep() error { @@ -50,10 +58,13 @@ func (t *Thread) resume() error { return nil } -func (t *Thread) blocked() bool { +func (thread *Thread) blocked() bool { // TODO(dp) cache the func pc to remove this lookup - pc, _ := t.PC() - fn := 
t.dbp.goSymTable.PCToFunc(pc) + pc, err := thread.PC() + if err != nil { + return false + } + fn := thread.dbp.goSymTable.PCToFunc(pc) if fn == nil { return false } @@ -65,6 +76,10 @@ func (t *Thread) blocked() bool { } } +func (thread *Thread) stopped() bool { + return C.thread_blocked(thread.os.thread_act) > C.int(0) +} + func (thread *Thread) writeMemory(addr uintptr, data []byte) (int, error) { if len(data) == 0 { return 0, nil diff --git a/proc/threads_darwin.h b/proc/threads_darwin.h index 156c410f..4a64d49d 100644 --- a/proc/threads_darwin.h +++ b/proc/threads_darwin.h @@ -30,3 +30,6 @@ set_registers(mach_port_name_t, x86_thread_state64_t*); kern_return_t get_identity(mach_port_name_t, thread_identifier_info_data_t *); + +int +thread_blocked(thread_act_t thread); diff --git a/proc/threads_linux.go b/proc/threads_linux.go index db8a71a0..ecfa7a28 100644 --- a/proc/threads_linux.go +++ b/proc/threads_linux.go @@ -12,20 +12,31 @@ type OSSpecificDetails struct { registers sys.PtraceRegs } -func (t *Thread) Halt() error { - if stopped(t.Id) { - return nil +func (t *Thread) Halt() (err error) { + defer func() { + if err == nil { + t.running = false + } + }() + if t.Stopped() { + return } - err := sys.Tgkill(t.dbp.Pid, t.Id, sys.SIGSTOP) + err = sys.Tgkill(t.dbp.Pid, t.Id, sys.SIGSTOP) if err != nil { - return fmt.Errorf("halt err %s on thread %d", err, t.Id) + err = fmt.Errorf("halt err %s on thread %d", err, t.Id) + return } _, _, err = wait(t.Id, t.dbp.Pid, 0) if err != nil { - return fmt.Errorf("wait err %s on thread %d", err, t.Id) + err = fmt.Errorf("wait err %s on thread %d", err, t.Id) + return } - t.running = false - return nil + return +} + +func (thread *Thread) stopped() bool { + state := status(thread.Id) + return state == STATUS_TRACE_STOP } func (t *Thread) resume() (err error) {