Improve overall thread coordination

This commit is contained in:
Derek Parker
2014-11-07 23:44:24 -06:00
parent 4483b17bd6
commit 6b2ee09163
4 changed files with 243 additions and 178 deletions

View File

@ -3,9 +3,7 @@
package proctl
import (
"bytes"
"debug/gosym"
"encoding/binary"
"fmt"
"os"
"os/exec"
@ -40,6 +38,7 @@ type BreakPoint struct {
Line int
Addr uint64
OriginalData []byte
temp bool
}
type BreakPointExistsError struct {
@ -184,60 +183,14 @@ func (dbp *DebuggedProcess) addThread(tid int) (*ThreadContext, error) {
return dbp.Threads[tid], nil
}
// Sets a breakpoint in the running process.
// Sets a breakpoint in the current thread.
func (dbp *DebuggedProcess) Break(addr uintptr) (*BreakPoint, error) {
var (
int3 = []byte{0xCC}
f, l, fn = dbp.GoSymTable.PCToLine(uint64(addr))
originalData = make([]byte, 1)
)
if fn == nil {
return nil, InvalidAddressError{address: addr}
}
_, err := syscall.PtracePeekData(dbp.CurrentThread.Id, addr, originalData)
if err != nil {
return nil, err
}
if bytes.Equal(originalData, int3) {
return nil, BreakPointExistsError{f, l, addr}
}
_, err = syscall.PtracePokeData(dbp.CurrentThread.Id, addr, int3)
if err != nil {
return nil, err
}
breakpoint := &BreakPoint{
FunctionName: fn.Name,
File: f,
Line: l,
Addr: uint64(addr),
OriginalData: originalData,
}
dbp.BreakPoints[uint64(addr)] = breakpoint
return breakpoint, nil
return dbp.CurrentThread.Break(addr)
}
// Clears a breakpoint.
// Clears a breakpoint in the current thread.
func (dbp *DebuggedProcess) Clear(pc uint64) (*BreakPoint, error) {
bp, ok := dbp.BreakPoints[pc]
if !ok {
return nil, fmt.Errorf("No breakpoint currently set for %#v", pc)
}
_, err := syscall.PtracePokeData(dbp.CurrentThread.Id, uintptr(bp.Addr), bp.OriginalData)
if err != nil {
return nil, err
}
delete(dbp.BreakPoints, pc)
return bp, nil
return dbp.CurrentThread.Clear(pc)
}
// Returns the status of the current main thread context.
@ -288,8 +241,10 @@ func (dbp *DebuggedProcess) Step() (err error) {
func (dbp *DebuggedProcess) Next() error {
for _, thread := range dbp.Threads {
err := thread.Next()
if _, ok := err.(ProcessExitedError); !ok {
return err
if err != nil {
if _, ok := err.(ProcessExitedError); !ok {
return err
}
}
}
@ -305,7 +260,7 @@ func (dbp *DebuggedProcess) Continue() error {
}
}
_, _, err := wait(dbp, -1, 0)
_, _, err := trapWait(dbp, -1, 0)
if err != nil {
if _, ok := err.(ProcessExitedError); !ok {
return err
@ -402,20 +357,6 @@ func (dbp *DebuggedProcess) obtainGoSymbols(wg *sync.WaitGroup) {
dbp.GoSymTable = tab
}
// Takes an offset from RSP and returns the address of the
// instruction the currect function is going to return to.
func (dbp *DebuggedProcess) ReturnAddressFromOffset(offset int64) uint64 {
regs, err := dbp.Registers()
if err != nil {
panic("Could not obtain register values")
}
retaddr := int64(regs.Rsp) + offset
data := make([]byte, 8)
syscall.PtracePeekText(dbp.Pid, uintptr(retaddr), data)
return binary.LittleEndian.Uint64(data)
}
type ProcessExitedError struct {
pid int
}
@ -424,110 +365,82 @@ func (pe ProcessExitedError) Error() string {
return fmt.Sprintf("process %d has exited", pe.pid)
}
func wait(dbp *DebuggedProcess, p int, options int) (int, *syscall.WaitStatus, error) {
func trapWait(dbp *DebuggedProcess, p int, options int) (int, *syscall.WaitStatus, error) {
var status syscall.WaitStatus
for {
pid, status, err := timeoutWait(p, options)
pid, err := syscall.Wait4(p, &status, syscall.WALL|options, nil)
if err != nil {
if _, ok := err.(TimeoutError); ok {
return p, nil, nil
}
return -1, nil, fmt.Errorf("wait err %s %d", err, pid)
}
thread, threadtraced := dbp.Threads[pid]
if threadtraced {
thread.Status = status
if !threadtraced {
continue
}
thread.Status = &status
if status.Exited() {
if pid == dbp.Pid {
return 0, nil, ProcessExitedError{pid}
}
delete(dbp.Threads, pid)
continue
}
if status.StopSignal() == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_CLONE {
// A traced thread has cloned a new thread, grab the pid and
// add it to our list of traced threads.
msg, err := syscall.PtraceGetEventMsg(pid)
if err != nil {
return 0, nil, fmt.Errorf("could not get event message: %s", err)
}
_, err = dbp.addThread(int(msg))
if err != nil {
if _, ok := err.(ProcessExitedError); ok {
continue
}
return 0, nil, err
}
err = syscall.PtraceCont(int(msg), 0)
if err != nil {
return 0, nil, fmt.Errorf("could not continue new thread %d %s", msg, err)
}
err = syscall.PtraceCont(pid, 0)
if err != nil {
return 0, nil, fmt.Errorf("could not continue stopped thread %d %s", pid, err)
}
continue
switch status.TrapCause() {
case syscall.PTRACE_EVENT_CLONE:
addNewThread(dbp, pid)
default:
pc, err := thread.CurrentPC()
if err != nil {
return -1, nil, fmt.Errorf("could not get current pc %s", err)
}
if pid != dbp.CurrentThread.Id {
fmt.Printf("changed thread context from %d to %d\n", dbp.CurrentThread.Id, pid)
dbp.CurrentThread = thread
}
pc, _ := thread.CurrentPC()
// Check to see if we have hit a breakpoint
// that we know about.
if _, ok := dbp.BreakPoints[pc-1]; ok {
// Loop through all threads and ensure that we
// stop the rest of them, so that by the time
// we return control to the user, all threads
// are inactive. We send SIGSTOP and ensure all
// threads are in in signal-delivery-stop mode.
for _, th := range dbp.Threads {
if th.Id == pid {
// This thread is already stopped.
continue
}
ps, err := parseProcessStatus(pid)
if err != nil {
return -1, nil, err
}
if ps.state == STATUS_TRACE_STOP {
continue
}
err = syscall.Tgkill(dbp.Pid, th.Id, syscall.SIGALRM)
if err != nil {
return -1, nil, err
}
pid, err := syscall.Wait4(th.Id, nil, syscall.WALL, nil)
if err != nil {
return -1, nil, fmt.Errorf("wait err %s %d", err, pid)
}
// Check to see if we have hit a breakpoint.
if bp, ok := dbp.BreakPoints[pc-1]; ok {
if !bp.temp {
handleBreakPoint(dbp, thread, pid)
}
return pid, status, nil
return pid, &status, nil
}
}
if status.Stopped() {
if pid == dbp.Pid {
return pid, status, nil
}
// The thread has stopped, but has not hit a breakpoint.
// Continue the thread without returning control back
// to the console.
syscall.PtraceCont(pid, 0)
}
}
}
func addNewThread(dbp *DebuggedProcess, pid int) error {
// A traced thread has cloned a new thread, grab the pid and
// add it to our list of traced threads.
msg, err := syscall.PtraceGetEventMsg(pid)
if err != nil {
return fmt.Errorf("could not get event message: %s", err)
}
fmt.Println("new thread spawned", msg)
_, err = dbp.addThread(int(msg))
if err != nil {
return err
}
err = syscall.PtraceCont(int(msg), 0)
if err != nil {
return fmt.Errorf("could not continue new thread %d %s", msg, err)
}
err = syscall.PtraceCont(pid, 0)
if err != nil {
return fmt.Errorf("could not continue stopped thread %d %s", pid, err)
}
return nil
}
type waitstats struct {
pid int
status *syscall.WaitStatus
@ -550,6 +463,17 @@ func timeoutWait(pid int, options int) (int, *syscall.WaitStatus, error) {
errchan = make(chan error)
)
if pid > 0 {
ps, err := parseProcessStatus(pid)
if err != nil {
return -1, nil, err
}
if ps.state == STATUS_SLEEPING {
return 0, nil, nil
}
}
go func(pid int) {
pid, err := syscall.Wait4(pid, &status, syscall.WALL|options, nil)
if err != nil {
@ -562,16 +486,57 @@ func timeoutWait(pid int, options int) (int, *syscall.WaitStatus, error) {
select {
case s := <-statchan:
return s.pid, s.status, nil
case <-time.After(1 * time.Second):
ps, err := parseProcessStatus(pid)
if err != nil {
return -1, nil, err
case <-time.After(2 * time.Second):
if pid > 0 {
ps, err := parseProcessStatus(pid)
if err != nil {
return -1, nil, err
}
syscall.Tgkill(ps.ppid, ps.pid, syscall.SIGSTOP)
}
syscall.Tgkill(ps.ppid, ps.pid, syscall.SIGSTOP)
return pid, nil, TimeoutError{pid}
return 0, nil, TimeoutError{pid}
case err := <-errchan:
return -1, nil, err
}
}
func handleBreakPoint(dbp *DebuggedProcess, thread *ThreadContext, pid int) error {
if pid != dbp.CurrentThread.Id {
fmt.Printf("thread context changed from %d to %d\n", dbp.CurrentThread.Id, pid)
dbp.CurrentThread = thread
}
// Loop through all threads and ensure that we
// stop the rest of them, so that by the time
// we return control to the user, all threads
// are inactive. We send SIGSTOP and ensure all
// threads are in in signal-delivery-stop mode.
for _, th := range dbp.Threads {
if th.Id == pid {
// This thread is already stopped.
continue
}
ps, err := parseProcessStatus(th.Id)
if err != nil {
return err
}
if ps.state == STATUS_TRACE_STOP {
continue
}
err = syscall.Tgkill(dbp.Pid, th.Id, syscall.SIGSTOP)
if err != nil {
return err
}
pid, err := syscall.Wait4(th.Id, nil, syscall.WALL, nil)
if err != nil {
return fmt.Errorf("wait err %s %d", err, pid)
}
}
return nil
}