diff --git a/proctl/proctl_test.go b/proctl/proctl_test.go index 57f57d44..b8bc88c0 100644 --- a/proctl/proctl_test.go +++ b/proctl/proctl_test.go @@ -7,6 +7,8 @@ import ( "testing" ) +type testfunc func(p *DebuggedProcess) + func dataAtAddr(pid int, addr uint64) ([]byte, error) { data := make([]byte, 1) _, err := syscall.PtracePeekData(pid, uintptr(addr), data) @@ -26,6 +28,22 @@ func getRegisters(p *DebuggedProcess, t *testing.T) *syscall.PtraceRegs { return regs } +func withTestProcess(name string, t *testing.T, fn testfunc) { + cmd, err := StartTestProcess(name) + if err != nil { + t.Fatal("Starting test process:", err) + } + + pid := cmd.Process.Pid + p, err := NewDebugProcess(pid) + if err != nil { + t.Fatal("NewDebugProcess():", err) + } + defer cmd.Process.Kill() + + fn(p) +} + func StartTestProcess(name string) (*exec.Cmd, error) { cmd := exec.Command("../fixtures/" + name) @@ -38,187 +56,123 @@ func StartTestProcess(name string) (*exec.Cmd, error) { } func TestAttachProcess(t *testing.T) { - cmd, err := StartTestProcess("testprog") - if err != nil { - t.Fatal("Starting test process:", err) - } - - pid := cmd.Process.Pid - p, err := NewDebugProcess(pid) - if err != nil { - t.Fatal("NewDebugProcess():", err) - } - - if !p.ProcessState.Sys().(syscall.WaitStatus).Stopped() { - t.Errorf("Process was not stopped correctly") - } - - cmd.Process.Kill() + withTestProcess("testprog", t, func(p *DebuggedProcess) { + if !p.ProcessState.Sys().(syscall.WaitStatus).Stopped() { + t.Errorf("Process was not stopped correctly") + } + }) } func TestStep(t *testing.T) { - cmd, err := StartTestProcess("testprog") - if err != nil { - t.Fatal("Starting test process:", err) - } + withTestProcess("testprog", t, func(p *DebuggedProcess) { + regs := getRegisters(p, t) + rip := regs.PC() - pid := cmd.Process.Pid - p, err := NewDebugProcess(pid) - if err != nil { - t.Fatal("NewDebugProcess():", err) - } + err := p.Step() + if err != nil { + t.Fatal("Step():", err) + } - regs := getRegisters(p, t) - rip := regs.PC() + regs = getRegisters(p, t) - err = p.Step() - if err != nil { - t.Fatal("Step():", err) - } - - regs = getRegisters(p, t) - - if rip >= regs.PC() { - t.Errorf("Expected %#v to be greater than %#v", regs.PC(), rip) - } - - cmd.Process.Kill() + if rip >= regs.PC() { + t.Errorf("Expected %#v to be greater than %#v", regs.PC(), rip) + } + }) } func TestContinue(t *testing.T) { - cmd, err := StartTestProcess("continuetestprog") - if err != nil { - t.Fatal("Starting test process:", err) - } + withTestProcess("continuetestprog", t, func(p *DebuggedProcess) { + if p.ProcessState.Exited() { + t.Fatal("Process already exited") + } - pid := cmd.Process.Pid - p, err := NewDebugProcess(pid) - if err != nil { - t.Fatal("NewDebugProcess():", err) - } + err := p.Continue() + if err != nil { + t.Fatal("Continue():", err) + } - if p.ProcessState.Exited() { - t.Fatal("Process already exited") - } - - err = p.Continue() - if err != nil { - t.Fatal("Continue():", err) - } - - if !p.ProcessState.Success() { - t.Fatal("Process did not exit successfully") - } + if !p.ProcessState.Success() { + t.Fatal("Process did not exit successfully") + } + }) } func TestBreakPoint(t *testing.T) { - cmd, err := StartTestProcess("testprog") - if err != nil { - t.Fatal("Starting test process:", err) - } + withTestProcess("testprog", t, func(p *DebuggedProcess) { + sleepytimefunc := p.GoSymTable.LookupFunc("main.sleepytime") + sleepyaddr := sleepytimefunc.Entry - pid := cmd.Process.Pid - p, err := NewDebugProcess(pid) - if err != nil { - t.Fatal("NewDebugProcess():", err) - } + bp, err := p.Break(uintptr(sleepyaddr)) + if err != nil { + t.Fatal("Break():", err) + } - sleepytimefunc := p.GoSymTable.LookupFunc("main.sleepytime") - sleepyaddr := sleepytimefunc.Entry + breakpc := bp.Addr + 1 + err = p.Continue() + if err != nil { + t.Fatal("Continue():", err) + } - bp, err := p.Break(uintptr(sleepyaddr)) - if err != nil { - t.Fatal("Break():", err) - } + regs := getRegisters(p, t) - breakpc := bp.Addr + 1 - err = p.Continue() - if err != nil { - t.Fatal("Continue():", err) - } + pc := regs.PC() + if pc != breakpc { + t.Fatalf("Break not respected:\nPC:%d\nFN:%d\n", pc, breakpc) + } - regs := getRegisters(p, t) + err = p.Step() + if err != nil { + t.Fatal(err) + } - pc := regs.PC() - if pc != breakpc { - t.Fatalf("Break not respected:\nPC:%d\nFN:%d\n", pc, breakpc) - } + regs = getRegisters(p, t) - err = p.Step() - if err != nil { - t.Fatal(err) - } - - regs = getRegisters(p, t) - - pc = regs.PC() - if pc == breakpc { - t.Fatalf("Step not respected:\nPC:%d\nFN:%d\n", pc, breakpc) - } - - cmd.Process.Kill() + pc = regs.PC() + if pc == breakpc { + t.Fatalf("Step not respected:\nPC:%d\nFN:%d\n", pc, breakpc) + } + }) } func TestBreakPointWithNonExistantFunction(t *testing.T) { - cmd, err := StartTestProcess("testprog") - if err != nil { - t.Fatal("Starting test process:", err) - } - - pid := cmd.Process.Pid - p, err := NewDebugProcess(pid) - if err != nil { - t.Fatal("NewDebugProcess():", err) - } - - _, err = p.Break(uintptr(0)) - if err == nil { - t.Fatal("Should not be able to break at non existant function") - } - - cmd.Process.Kill() + withTestProcess("testprog", t, func(p *DebuggedProcess) { + _, err := p.Break(uintptr(0)) + if err == nil { + t.Fatal("Should not be able to break at non existant function") + } + }) } func TestClearBreakPoint(t *testing.T) { - cmd, err := StartTestProcess("testprog") - if err != nil { - t.Fatal("Starting test process:", err) - } + withTestProcess("testprog", t, func(p *DebuggedProcess) { + fn := p.GoSymTable.LookupFunc("main.sleepytime") + bp, err := p.Break(uintptr(fn.Entry)) + if err != nil { + t.Fatal("Break():", err) + } - pid := cmd.Process.Pid - p, err := NewDebugProcess(pid) - if err != nil { - t.Fatal("NewDebugProcess():", err) - } + int3, err := dataAtAddr(p.Pid, bp.Addr) + if err != nil { + t.Fatal(err) + } - fn := p.GoSymTable.LookupFunc("main.sleepytime") - bp, err := p.Break(uintptr(fn.Entry)) - if err != nil { - t.Fatal("Break():", err) - } + bp, err = p.Clear(fn.Entry) + if err != nil { + t.Fatal("Break():", err) + } - int3, err := dataAtAddr(pid, bp.Addr) - if err != nil { - t.Fatal(err) - } + data, err := dataAtAddr(p.Pid, bp.Addr) + if err != nil { + t.Fatal(err) + } - bp, err = p.Clear(fn.Entry) - if err != nil { - t.Fatal("Break():", err) - } + if bytes.Equal(data, int3) { + t.Fatalf("Breakpoint was not cleared data: %#v, int3: %#v", data, int3) + } - data, err := dataAtAddr(pid, bp.Addr) - if err != nil { - t.Fatal(err) - } - - if bytes.Equal(data, int3) { - t.Fatalf("Breakpoint was not cleared data: %#v, int3: %#v", data, int3) - } - - if len(p.BreakPoints) != 0 { - t.Fatal("Breakpoint not removed internally") - } - - cmd.Process.Kill() + if len(p.BreakPoints) != 0 { + t.Fatal("Breakpoint not removed internally") + } + }) }