diff --git a/Makefile b/Makefile index e1add3be..c32fa55a 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ endif test: ifeq "$(UNAME)" "Darwin" - go test $(PREFIX)/command $(PREFIX)/dwarf/frame $(PREFIX)/dwarf/op $(PREFIX)/dwarf/util + go test $(PREFIX)/command $(PREFIX)/dwarf/frame $(PREFIX)/dwarf/op $(PREFIX)/dwarf/util $(PREFIX)/source $(PREFIX)/dwarf/line cd proctl && go test -c $(PREFIX)/proctl && codesign -s $(CERT) ./proctl.test && ./proctl.test && rm ./proctl.test else go test ./... diff --git a/_fixtures/testprog.go b/_fixtures/testprog.go index 5715241c..3a46959f 100644 --- a/_fixtures/testprog.go +++ b/_fixtures/testprog.go @@ -6,14 +6,14 @@ import ( "time" ) -func sleepytime() { - time.Sleep(time.Millisecond) -} - func helloworld() { fmt.Println("Hello, World!") } +func sleepytime() { + time.Sleep(time.Millisecond) +} + func main() { for { sleepytime() diff --git a/_fixtures/testvisitorprog.go b/_fixtures/testvisitorprog.go new file mode 100644 index 00000000..a5865d65 --- /dev/null +++ b/_fixtures/testvisitorprog.go @@ -0,0 +1,59 @@ +package main + +import "fmt" + +func main() { + for { + for i := 0; i < 5; i++ { + if i == 0 { + fmt.Println("it is zero!") + } else if i == 1 { + fmt.Println("it is one") + } else { + fmt.Println("wat") + } + switch i { + case 3: + fmt.Println("three") + case 4: + fmt.Println("four") + } + } + fmt.Println("done") + } + { + fmt.Println("useless line") + } + fmt.Println("end") +} + +func noop() { + var ( + i = 1 + j = 2 + ) + + if j == 3 { + fmt.Println(i) + } + + fmt.Println(j) +} + +func looptest() { + for { + fmt.Println("wat") + if false { + fmt.Println("uh, wat") + break + } + } + fmt.Println("dun") +} + +func endlesslooptest() { + for { + fmt.Println("foo") + fmt.Println("foo") + } +} diff --git a/dwarf/line/line_parser.go b/dwarf/line/line_parser.go new file mode 100644 index 00000000..aa9e80f8 --- /dev/null +++ b/dwarf/line/line_parser.go @@ -0,0 +1,95 @@ +package line + +import ( + "bytes" + "encoding/binary" 
+ + "github.com/derekparker/delve/dwarf/util" +) + +type DebugLinePrologue struct { + Length uint32 + Version uint16 + PrologueLength uint32 + MinInstrLength uint8 + InitialIsStmt uint8 + LineBase int8 + LineRange uint8 + OpcodeBase uint8 + StdOpLengths []uint8 +} + +type DebugLineInfo struct { + Prologue *DebugLinePrologue + IncludeDirs []string + FileNames []*FileEntry + Instructions []byte +} + +type FileEntry struct { + Name string + DirIdx uint64 + LastModTime uint64 + Length uint64 +} + +func Parse(data []byte) *DebugLineInfo { + var ( + dbl = new(DebugLineInfo) + buf = bytes.NewBuffer(data) + ) + + parseDebugLinePrologue(dbl, buf) + parseIncludeDirs(dbl, buf) + parseFileEntries(dbl, buf) + dbl.Instructions = buf.Bytes() + + return dbl +} + +func parseDebugLinePrologue(dbl *DebugLineInfo, buf *bytes.Buffer) { + p := new(DebugLinePrologue) + + p.Length = binary.LittleEndian.Uint32(buf.Next(4)) + p.Version = binary.LittleEndian.Uint16(buf.Next(2)) + p.PrologueLength = binary.LittleEndian.Uint32(buf.Next(4)) + p.MinInstrLength = uint8(buf.Next(1)[0]) + p.InitialIsStmt = uint8(buf.Next(1)[0]) + p.LineBase = int8(buf.Next(1)[0]) + p.LineRange = uint8(buf.Next(1)[0]) + p.OpcodeBase = uint8(buf.Next(1)[0]) + + p.StdOpLengths = make([]uint8, p.OpcodeBase-1) + binary.Read(buf, binary.LittleEndian, &p.StdOpLengths) + + dbl.Prologue = p +} + +func parseIncludeDirs(info *DebugLineInfo, buf *bytes.Buffer) { + for { + str, _ := util.ParseString(buf) + if str == "" { + break + } + + info.IncludeDirs = append(info.IncludeDirs, str) + } +} + +func parseFileEntries(info *DebugLineInfo, buf *bytes.Buffer) { + for { + entry := new(FileEntry) + + name, _ := util.ParseString(buf) + if name == "" { + break + } + + entry.Name = name + entry.DirIdx, _ = util.DecodeULEB128(buf) + entry.LastModTime, _ = util.DecodeULEB128(buf) + entry.Length, _ = util.DecodeULEB128(buf) + + info.FileNames = append(info.FileNames, entry) + } +} diff --git a/dwarf/line/line_parser_test.go 
b/dwarf/line/line_parser_test.go new file mode 100644 index 00000000..4e558109 --- /dev/null +++ b/dwarf/line/line_parser_test.go @@ -0,0 +1,103 @@ +package line + +import ( + "debug/elf" + "debug/macho" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/davecheney/profile" +) + +func grabDebugLineSection(p string, t *testing.T) []byte { + f, err := os.Open(p) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + ef, err := elf.NewFile(f) + if err == nil { + data, _ := ef.Section(".debug_line").Data() + return data + } + + mf, _ := macho.NewFile(f) + data, _ := mf.Section("__debug_line").Data() + + return data +} + +func TestDebugLinePrologueParser(t *testing.T) { + // Test against known good values, from readelf --debug-dump=rawline _fixtures/testnextprog + p, err := filepath.Abs("../../_fixtures/testnextprog") + if err != nil { + t.Fatal(err) + } + + err = exec.Command("go", "build", "-gcflags=-N -l", "-o", p, p+".go").Run() + if err != nil { + t.Fatal("Could not compile test file", p, err) + } + defer os.Remove(p) + data := grabDebugLineSection(p, t) + dbl := Parse(data) + prologue := dbl.Prologue + + if prologue.Version != uint16(2) { + t.Fatal("Version not parsed correctly", prologue.Version) + } + + if prologue.MinInstrLength != uint8(1) { + t.Fatal("Minimun Instruction Length not parsed correctly", prologue.MinInstrLength) + } + + if prologue.InitialIsStmt != uint8(1) { + t.Fatal("Initial value of 'is_stmt' not parsed correctly", prologue.InitialIsStmt) + } + + if prologue.LineBase != int8(-1) { + t.Fatal("Line base not parsed correctly", prologue.LineBase) + } + + if prologue.LineRange != uint8(4) { + t.Fatal("Line Range not parsed correctly", prologue.LineRange) + } + + if prologue.OpcodeBase != uint8(10) { + t.Fatal("Opcode Base not parsed correctly", prologue.OpcodeBase) + } + + lengths := []uint8{0, 1, 1, 1, 1, 0, 0, 0, 1} + for i, l := range prologue.StdOpLengths { + if l != lengths[i] { + t.Fatal("Length not parsed 
correctly", l) + } + } + + if len(dbl.IncludeDirs) != 0 { + t.Fatal("Include dirs not parsed correctly") + } + + if !strings.Contains(dbl.FileNames[0].Name, "/delve/_fixtures/testnextprog.go") { + t.Fatal("First file entry not parsed correctly") + } +} + +func BenchmarkLineParser(b *testing.B) { + defer profile.Start(profile.MemProfile).Stop() + p, err := filepath.Abs("../../_fixtures/testnextprog") + if err != nil { + b.Fatal(err) + } + + data := grabDebugLineSection(p, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Parse(data) + } +} diff --git a/dwarf/line/state_machine.go b/dwarf/line/state_machine.go new file mode 100644 index 00000000..0235f255 --- /dev/null +++ b/dwarf/line/state_machine.go @@ -0,0 +1,223 @@ +package line + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/derekparker/delve/dwarf/util" +) + +type Location struct { + File string + Line int + Address uint64 + Delta int +} + +type StateMachine struct { + dbl *DebugLineInfo + file string + line int + address uint64 + column uint + isStmt bool + basicBlock bool + endSeq bool + lastWasStandard bool + lastDelta int +} + +type opcodefn func(*StateMachine, *bytes.Buffer) + +// Special opcodes +const ( + DW_LNS_copy = 1 + DW_LNS_advance_pc = 2 + DW_LNS_advance_line = 3 + DW_LNS_set_file = 4 + DW_LNS_set_column = 5 + DW_LNS_negate_stmt = 6 + DW_LNS_set_basic_block = 7 + DW_LNS_const_add_pc = 8 + DW_LNS_fixed_advance_pc = 9 +) + +// Extended opcodes +const ( + DW_LINE_end_sequence = 1 + DW_LINE_set_address = 2 + DW_LINE_define_file = 3 +) + +var standardopcodes = map[byte]opcodefn{ + DW_LNS_copy: copyfn, + DW_LNS_advance_pc: advancepc, + DW_LNS_advance_line: advanceline, + DW_LNS_set_file: setfile, + DW_LNS_set_column: setcolumn, + DW_LNS_negate_stmt: negatestmt, + DW_LNS_set_basic_block: setbasicblock, + DW_LNS_const_add_pc: constaddpc, + DW_LNS_fixed_advance_pc: fixedadvancepc, +} + +var extendedopcodes = map[byte]opcodefn{ + DW_LINE_end_sequence: endsequence, + 
DW_LINE_set_address: setaddress, + DW_LINE_define_file: definefile, +} + +func newStateMachine(dbl *DebugLineInfo) *StateMachine { + return &StateMachine{dbl: dbl, file: dbl.FileNames[0].Name, line: 1} +} + +// Returns all PCs for a given file/line. Useful for loops where the 'for' line +// could be split amongst 2 PCs. +func (dbl *DebugLineInfo) AllPCsForFileLine(f string, l int) (pcs []uint64) { + var ( + foundFile bool + lastAddr uint64 + sm = newStateMachine(dbl) + buf = bytes.NewBuffer(dbl.Instructions) + ) + + for b, err := buf.ReadByte(); err == nil; b, err = buf.ReadByte() { + findAndExecOpcode(sm, buf, b) + if foundFile && sm.file != f { + return + } + if sm.line == l && sm.file == f && sm.address != lastAddr { + foundFile = true + pcs = append(pcs, sm.address) + line := sm.line + // Keep going until we're on a different line. We only care about + // when a line comes back around (i.e. for loop) so get to next line, + // and try to find the line we care about again. + for b, err := buf.ReadByte(); err == nil; b, err = buf.ReadByte() { + findAndExecOpcode(sm, buf, b) + if line < sm.line { + break + } + } + } + } + return +} + +func findAndExecOpcode(sm *StateMachine, buf *bytes.Buffer, b byte) { + switch { + case b == 0: + execExtendedOpcode(sm, b, buf) + case b < sm.dbl.Prologue.OpcodeBase: + execStandardOpcode(sm, b, buf) + default: + execSpecialOpcode(sm, b) + } +} + +func execSpecialOpcode(sm *StateMachine, instr byte) { + var ( + opcode = uint8(instr) + decoded = opcode - sm.dbl.Prologue.OpcodeBase + ) + + if sm.dbl.Prologue.InitialIsStmt == uint8(1) { + sm.isStmt = true + } + + sm.lastDelta = int(sm.dbl.Prologue.LineBase + int8(decoded%sm.dbl.Prologue.LineRange)) + sm.line += sm.lastDelta + sm.address += uint64(decoded / sm.dbl.Prologue.LineRange) + sm.basicBlock = false + sm.lastWasStandard = false +} + +func execExtendedOpcode(sm *StateMachine, instr byte, buf *bytes.Buffer) { + _, _ = util.DecodeULEB128(buf) + b, _ := buf.ReadByte() + fn, ok := 
extendedopcodes[b] + if !ok { + panic(fmt.Sprintf("Encountered unknown extended opcode %#v\n", b)) + } + sm.lastWasStandard = false + + fn(sm, buf) +} + +func execStandardOpcode(sm *StateMachine, instr byte, buf *bytes.Buffer) { + fn, ok := standardopcodes[instr] + if !ok { + panic(fmt.Sprintf("Encountered unknown standard opcode %#v\n", instr)) + } + sm.lastWasStandard = true + + fn(sm, buf) +} + +func copyfn(sm *StateMachine, buf *bytes.Buffer) { + sm.basicBlock = false +} + +func advancepc(sm *StateMachine, buf *bytes.Buffer) { + addr, _ := util.DecodeULEB128(buf) + sm.address += addr * uint64(sm.dbl.Prologue.MinInstrLength) +} + +func advanceline(sm *StateMachine, buf *bytes.Buffer) { + line, _ := util.DecodeSLEB128(buf) + sm.line += int(line) + sm.lastDelta = int(line) +} + +func setfile(sm *StateMachine, buf *bytes.Buffer) { + i, _ := util.DecodeULEB128(buf) + sm.file = sm.dbl.FileNames[i-1].Name +} + +func setcolumn(sm *StateMachine, buf *bytes.Buffer) { + c, _ := util.DecodeULEB128(buf) + sm.column = uint(c) +} + +func negatestmt(sm *StateMachine, buf *bytes.Buffer) { + sm.isStmt = !sm.isStmt +} + +func setbasicblock(sm *StateMachine, buf *bytes.Buffer) { + sm.basicBlock = true +} + +func constaddpc(sm *StateMachine, buf *bytes.Buffer) { + sm.address += (255 / uint64(sm.dbl.Prologue.LineRange)) +} + +func fixedadvancepc(sm *StateMachine, buf *bytes.Buffer) { + var operand uint16 + binary.Read(buf, binary.LittleEndian, &operand) + + sm.address += uint64(operand) +} + +func endsequence(sm *StateMachine, buf *bytes.Buffer) { + sm.endSeq = true +} + +func setaddress(sm *StateMachine, buf *bytes.Buffer) { + var addr uint64 + + binary.Read(buf, binary.LittleEndian, &addr) + + sm.address = addr +} + +func definefile(sm *StateMachine, buf *bytes.Buffer) { + var ( + _, _ = util.ParseString(buf) + _, _ = util.DecodeULEB128(buf) + _, _ = util.DecodeULEB128(buf) + _, _ = util.DecodeULEB128(buf) + ) + + // Don't do anything here yet. 
+} diff --git a/proctl/breakpoints.go b/proctl/breakpoints.go index a6ba9a59..e77cfb69 100644 --- a/proctl/breakpoints.go +++ b/proctl/breakpoints.go @@ -55,10 +55,8 @@ func (dbp *DebuggedProcess) BreakpointExists(addr uint64) bool { return true } } - if _, ok := dbp.BreakPoints[addr]; ok { - return true - } - return false + _, ok := dbp.BreakPoints[addr] + return ok } func (dbp *DebuggedProcess) newBreakpoint(fn, f string, l int, addr uint64, data []byte) *BreakPoint { @@ -108,29 +106,3 @@ func (dbp *DebuggedProcess) setBreakpoint(tid int, addr uint64) (*BreakPoint, er dbp.BreakPoints[addr] = dbp.newBreakpoint(fn.Name, f, l, addr, originalData) return dbp.BreakPoints[addr], nil } - -func (dbp *DebuggedProcess) clearBreakpoint(tid int, addr uint64) (*BreakPoint, error) { - // Check for hardware breakpoint - for i, bp := range dbp.HWBreakPoints { - if bp == nil { - continue - } - if bp.Addr == addr { - dbp.HWBreakPoints[i] = nil - if err := clearHardwareBreakpoint(i, tid); err != nil { - return nil, err - } - return bp, nil - } - } - // Check for software breakpoint - if bp, ok := dbp.BreakPoints[addr]; ok { - thread := dbp.Threads[tid] - if _, err := writeMemory(thread, uintptr(bp.Addr), bp.OriginalData); err != nil { - return nil, fmt.Errorf("could not clear breakpoint %s", err) - } - delete(dbp.BreakPoints, addr) - return bp, nil - } - return nil, fmt.Errorf("No breakpoint currently set for %#v", addr) -} diff --git a/proctl/proctl.go b/proctl/proctl.go index 76fe3ac9..28abb718 100644 --- a/proctl/proctl.go +++ b/proctl/proctl.go @@ -11,12 +11,15 @@ import ( "path/filepath" "strconv" "strings" + "sync" "syscall" sys "golang.org/x/sys/unix" "github.com/derekparker/delve/dwarf/frame" + "github.com/derekparker/delve/dwarf/line" "github.com/derekparker/delve/dwarf/reader" + "github.com/derekparker/delve/source" ) // Struct representing a debugged process. 
Holds onto pid, register values, @@ -27,11 +30,13 @@ type DebuggedProcess struct { Dwarf *dwarf.Data GoSymTable *gosym.Table FrameEntries frame.FrameDescriptionEntries + LineInfo *line.DebugLineInfo HWBreakPoints [4]*BreakPoint BreakPoints map[uint64]*BreakPoint Threads map[int]*ThreadContext CurrentThread *ThreadContext os *OSProcessDetails + ast *source.Searcher breakpointIDCounter int running bool halt bool @@ -100,6 +105,28 @@ func (dbp *DebuggedProcess) Running() bool { return dbp.running } +// Finds the executable and then uses it +// to parse the following information: +// * Dwarf .debug_frame section +// * Dwarf .debug_line section +// * Go symbol table. +func (dbp *DebuggedProcess) LoadInformation() error { + var wg sync.WaitGroup + + exe, err := dbp.findExecutable() + if err != nil { + return err + } + + wg.Add(3) + go dbp.parseDebugFrame(exe, &wg) + go dbp.obtainGoSymbols(exe, &wg) + go dbp.parseDebugLineInfo(exe, &wg) + wg.Wait() + + return nil +} + // Find a location by string (file+line, function, breakpoint id, addr) func (dbp *DebuggedProcess) FindLocation(str string) (uint64, error) { // File + Line @@ -188,7 +215,30 @@ func (dbp *DebuggedProcess) BreakByLocation(loc string) (*BreakPoint, error) { // Clears a breakpoint in the current thread. 
func (dbp *DebuggedProcess) Clear(addr uint64) (*BreakPoint, error) { - return dbp.clearBreakpoint(dbp.CurrentThread.Id, addr) + tid := dbp.CurrentThread.Id + // Check for hardware breakpoint + for i, bp := range dbp.HWBreakPoints { + if bp == nil { + continue + } + if bp.Addr == addr { + dbp.HWBreakPoints[i] = nil + if err := clearHardwareBreakpoint(i, tid); err != nil { + return nil, err + } + return bp, nil + } + } + // Check for software breakpoint + if bp, ok := dbp.BreakPoints[addr]; ok { + thread := dbp.Threads[tid] + if _, err := writeMemory(thread, uintptr(bp.Addr), bp.OriginalData); err != nil { + return nil, fmt.Errorf("could not clear breakpoint %s", err) + } + delete(dbp.BreakPoints, addr) + return bp, nil + } + return nil, fmt.Errorf("no breakpoint at %#v", addr) } // Clears a breakpoint by location (function, file+line, address, breakpoint id) @@ -207,32 +257,59 @@ func (dbp *DebuggedProcess) Status() *sys.WaitStatus { // Step over function calls. func (dbp *DebuggedProcess) Next() error { - var runnable []*ThreadContext + return dbp.run(dbp.next) +} - fn := func() error { - for _, th := range dbp.Threads { - // Continue any blocked M so that the - // scheduler can continue to do its' - // job correctly. - if th.blocked() { - err := th.Continue() - if err != nil { - return err - } - continue - } - - runnable = append(runnable, th) - } - for _, th := range runnable { - err := th.Next() - if err != nil && err != sys.ESRCH { +func (dbp *DebuggedProcess) next() error { + curg, err := dbp.CurrentThread.curG() + if err != nil { + return err + } + defer dbp.clearTempBreakpoints() + for _, th := range dbp.Threads { + if th.blocked() { // Continue threads that aren't running go code. 
+ if err := th.Continue(); err != nil { return err } + continue + } + if err := th.Next(); err != nil { + return err } - return dbp.Halt() } - return dbp.run(fn) + + for { + tid, err := trapWait(dbp, -1) + if err != nil { + return err + } + th, ok := dbp.Threads[tid] + if !ok { + return fmt.Errorf("unknown thread %d", tid) + } + pc, err := th.CurrentPC() + if err != nil { + return err + } + // Check if we've hit a software breakpoint. If so, reset PC. + if err = th.clearTempBreakpoint(pc - 1); err != nil { + return err + } + // Grab the current goroutine for this thread. + tg, err := th.curG() + if err != nil { + return err + } + // Make sure we're on the same goroutine. + // TODO(dp) take into account goroutine exit. + if tg.id == curg.id { + if dbp.CurrentThread.Id != tid { + dbp.SwitchThread(tid) + } + break + } + } + return dbp.Halt() } // Resume process. @@ -256,8 +333,7 @@ func (dbp *DebuggedProcess) Continue() error { } if wpid != dbp.CurrentThread.Id { - fmt.Printf("thread context changed from %d to %d\n", dbp.CurrentThread.Id, thread.Id) - dbp.CurrentThread = thread + dbp.SwitchThread(wpid) } pc, err := thread.CurrentPC() @@ -323,6 +399,7 @@ func (dbp *DebuggedProcess) Step() (err error) { func (dbp *DebuggedProcess) SwitchThread(tid int) error { if th, ok := dbp.Threads[tid]; ok { dbp.CurrentThread = th + fmt.Printf("thread context changed from %d to %d\n", dbp.CurrentThread.Id, tid) return nil } return fmt.Errorf("thread %d does not exist", tid) @@ -344,11 +421,28 @@ func (dbp *DebuggedProcess) EvalSymbol(name string) (*Variable, error) { return dbp.CurrentThread.EvalSymbol(name) } +func (dbp *DebuggedProcess) CallFn(name string, fn func(*ThreadContext) error) error { + return dbp.CurrentThread.CallFn(name, fn) +} + // Returns a reader for the dwarf data func (dbp *DebuggedProcess) DwarfReader() *reader.Reader { return reader.New(dbp.Dwarf) } +// Finds the breakpoint for the given pc. 
+func (dbp *DebuggedProcess) FindBreakpoint(pc uint64) (*BreakPoint, bool) { + for _, bp := range dbp.HWBreakPoints { + if bp != nil && bp.Addr == pc { + return bp, true + } + } + if bp, ok := dbp.BreakPoints[pc]; ok { + return bp, true + } + return nil, false +} + // Returns a new DebuggedProcess struct. func newDebugProcess(pid int, attach bool) (*DebuggedProcess, error) { dbp := DebuggedProcess{ @@ -356,6 +450,7 @@ func newDebugProcess(pid int, attach bool) (*DebuggedProcess, error) { Threads: make(map[int]*ThreadContext), BreakPoints: make(map[uint64]*BreakPoint), os: new(OSProcessDetails), + ast: source.New(), } if attach { @@ -386,6 +481,24 @@ func newDebugProcess(pid int, attach bool) (*DebuggedProcess, error) { return &dbp, nil } +func (dbp *DebuggedProcess) clearTempBreakpoints() error { + for _, bp := range dbp.HWBreakPoints { + if bp != nil && bp.Temp { + if _, err := dbp.Clear(bp.Addr); err != nil { + return err + } + } + } + for _, bp := range dbp.BreakPoints { + if !bp.Temp { + continue + } + if _, err := dbp.Clear(bp.Addr); err != nil { + return err + } + } + return nil +} func (dbp *DebuggedProcess) run(fn func() error) error { if dbp.exited { diff --git a/proctl/proctl_darwin.c b/proctl/proctl_darwin.c index 0c6e129c..0f4456a2 100644 --- a/proctl/proctl_darwin.c +++ b/proctl/proctl_darwin.c @@ -129,7 +129,6 @@ mach_port_wait(mach_port_t port_set) { mach_msg_port_descriptor_t *desc = (mach_msg_port_descriptor_t *)(bod + 1); thread = desc[0].name; - switch (msg.hdr.msgh_id) { case 2401: // Exception kret = thread_suspend(thread); diff --git a/proctl/proctl_darwin.go b/proctl/proctl_darwin.go index bf48ec68..5221e4c1 100644 --- a/proctl/proctl_darwin.go +++ b/proctl/proctl_darwin.go @@ -11,6 +11,7 @@ import ( "unsafe" "github.com/derekparker/delve/dwarf/frame" + "github.com/derekparker/delve/dwarf/line" sys "golang.org/x/sys/unix" ) @@ -31,40 +32,6 @@ func (dbp *DebuggedProcess) Halt() (err error) { return nil } -// Finds the executable and then uses 
it -// to parse the following information: -// * Dwarf .debug_frame section -// * Dwarf .debug_line section -// * Go symbol table. -func (dbp *DebuggedProcess) LoadInformation() error { - var ( - wg sync.WaitGroup - exe *macho.File - err error - ) - - ret := C.acquire_mach_task(C.int(dbp.Pid), &dbp.os.task, &dbp.os.portSet, &dbp.os.exceptionPort, &dbp.os.notificationPort) - if ret != C.KERN_SUCCESS { - return fmt.Errorf("could not acquire mach task %d", ret) - } - exe, err = dbp.findExecutable() - if err != nil { - return err - } - data, err := exe.DWARF() - if err != nil { - return err - } - dbp.Dwarf = data - - wg.Add(2) - go dbp.parseDebugFrame(exe, &wg) - go dbp.obtainGoSymbols(exe, &wg) - wg.Wait() - - return nil -} - func (dbp *DebuggedProcess) updateThreadList() error { var ( err error @@ -167,12 +134,41 @@ func (dbp *DebuggedProcess) obtainGoSymbols(exe *macho.File, wg *sync.WaitGroup) dbp.GoSymTable = tab } +func (dbp *DebuggedProcess) parseDebugLineInfo(exe *macho.File, wg *sync.WaitGroup) { + defer wg.Done() + + if sec := exe.Section("__debug_line"); sec != nil { + debugLine, err := exe.Section("__debug_line").Data() + if err != nil { + fmt.Println("could not get __debug_line section", err) + os.Exit(1) + } + dbp.LineInfo = line.Parse(debugLine) + } else { + fmt.Println("could not find __debug_line section in binary") + os.Exit(1) + } +} + func (dbp *DebuggedProcess) findExecutable() (*macho.File, error) { + ret := C.acquire_mach_task(C.int(dbp.Pid), &dbp.os.task, &dbp.os.portSet, &dbp.os.exceptionPort, &dbp.os.notificationPort) + if ret != C.KERN_SUCCESS { + return nil, fmt.Errorf("could not acquire mach task %d", ret) + } pathptr, err := C.find_executable(C.int(dbp.Pid)) if err != nil { return nil, err } - return macho.Open(C.GoString(pathptr)) + exe, err := macho.Open(C.GoString(pathptr)) + if err != nil { + return nil, err + } + data, err := exe.DWARF() + if err != nil { + return nil, err + } + dbp.Dwarf = data + return exe, nil } func trapWait(dbp 
*DebuggedProcess, pid int) (int, error) { diff --git a/proctl/proctl_linux.go b/proctl/proctl_linux.go index c4281d1f..c5b2ce5f 100644 --- a/proctl/proctl_linux.go +++ b/proctl/proctl_linux.go @@ -13,6 +13,7 @@ import ( sys "golang.org/x/sys/unix" "github.com/derekparker/delve/dwarf/frame" + "github.com/derekparker/delve/dwarf/line" ) const ( @@ -34,31 +35,6 @@ func (dbp *DebuggedProcess) Halt() (err error) { return nil } -// Finds the executable from /proc//exe and then -// uses that to parse the following information: -// * Dwarf .debug_frame section -// * Dwarf .debug_line section -// * Go symbol table. -func (dbp *DebuggedProcess) LoadInformation() error { - var ( - wg sync.WaitGroup - exe *elf.File - err error - ) - - exe, err = dbp.findExecutable() - if err != nil { - return err - } - - wg.Add(2) - go dbp.parseDebugFrame(exe, &wg) - go dbp.obtainGoSymbols(exe, &wg) - wg.Wait() - - return nil -} - // Attach to a newly created thread, and store that thread in our list of // known threads. 
func (dbp *DebuggedProcess) addThread(tid int, attach bool) (*ThreadContext, error) { @@ -103,6 +79,7 @@ func (dbp *DebuggedProcess) addThread(tid int, attach bool) (*ThreadContext, err dbp.Threads[tid] = &ThreadContext{ Id: tid, Process: dbp, + os: new(OSSpecificDetails), } if dbp.CurrentThread == nil { @@ -204,6 +181,22 @@ func (dbp *DebuggedProcess) obtainGoSymbols(exe *elf.File, wg *sync.WaitGroup) { dbp.GoSymTable = tab } +func (dbp *DebuggedProcess) parseDebugLineInfo(exe *elf.File, wg *sync.WaitGroup) { + defer wg.Done() + + if sec := exe.Section(".debug_line"); sec != nil { + debugLine, err := exe.Section(".debug_line").Data() + if err != nil { + fmt.Println("could not get .debug_line section", err) + os.Exit(1) + } + dbp.LineInfo = line.Parse(debugLine) + } else { + fmt.Println("could not find .debug_line section in binary") + os.Exit(1) + } +} + func stopped(pid int) bool { f, err := os.Open(fmt.Sprintf("/proc/%d/stat", pid)) if err != nil { diff --git a/proctl/proctl_test.go b/proctl/proctl_test.go index 159a8e31..0f488e22 100644 --- a/proctl/proctl_test.go +++ b/proctl/proctl_test.go @@ -216,8 +216,7 @@ func TestNext(t *testing.T) { {24, 26}, {26, 27}, {27, 34}, - {34, 35}, - {35, 41}, + {34, 41}, {41, 40}, {40, 41}, } @@ -232,18 +231,19 @@ func TestNext(t *testing.T) { _, err := p.Break(pc) assertNoError(err, t, "Break()") assertNoError(p.Continue(), t, "Continue()") + p.Clear(pc) f, ln := currentLineNumber(p, t) for _, tc := range testcases { if ln != tc.begin { - t.Fatalf("Program not stopped at correct spot expected %d was %s:%d", tc.begin, f, ln) + t.Fatalf("Program not stopped at correct spot expected %d was %s:%d", tc.begin, filepath.Base(f), ln) } assertNoError(p.Next(), t, "Next() returned an error") f, ln = currentLineNumber(p, t) if ln != tc.end { - t.Fatalf("Program did not continue to correct next location expected %d was %s:%d", tc.end, f, ln) + t.Fatalf("Program did not continue to correct next location expected %d was %s:%d", tc.end, 
filepath.Base(f), ln) } } @@ -355,3 +355,60 @@ func TestSwitchThread(t *testing.T) { } }) } + +func TestFunctionCall(t *testing.T) { + var testfile, _ = filepath.Abs("../_fixtures/testprog") + + withTestProcess(testfile, t, func(p *DebuggedProcess) { + pc, err := p.FindLocation("main.main") + if err != nil { + t.Fatal(err) + } + _, err = p.Break(pc) + if err != nil { + t.Fatal(err) + } + err = p.Continue() + if err != nil { + t.Fatal(err) + } + pc, err = p.CurrentPC() + if err != nil { + t.Fatal(err) + } + fn := p.GoSymTable.PCToFunc(pc) + if fn == nil { + t.Fatalf("Could not find func for PC: %#v", pc) + } + if fn.Name != "main.main" { + t.Fatal("Program stopped at incorrect place") + } + if err = p.CallFn("runtime.getg", func(th *ThreadContext) error { + pc, err := th.CurrentPC() + if err != nil { + t.Fatal(err) + } + f := th.Process.GoSymTable.LookupFunc("runtime.getg") + if f == nil { + t.Fatalf("could not find function %s", "runtime.getg") + } + if pc-1 != f.End-2 && pc != f.End-2 { + t.Fatalf("wrong pc expected %#v got %#v", f.End-2, pc-1) + } + return nil + }); err != nil { + t.Fatal(err) + } + pc, err = p.CurrentPC() + if err != nil { + t.Fatal(err) + } + fn = p.GoSymTable.PCToFunc(pc) + if fn == nil { + t.Fatalf("Could not find func for PC: %#v", pc) + } + if fn.Name != "main.main" { + t.Fatal("Program stopped at incorrect place") + } + }) +} diff --git a/proctl/threads.go b/proctl/threads.go index f3593b7c..a5d46860 100644 --- a/proctl/threads.go +++ b/proctl/threads.go @@ -5,8 +5,6 @@ import ( "fmt" sys "golang.org/x/sys/unix" - - "github.com/derekparker/delve/dwarf/frame" ) // ThreadContext represents a single thread in the traced process @@ -54,14 +52,14 @@ func (thread *ThreadContext) CurrentPC() (uint64, error) { // we step over any breakpoints. It will restore the instruction, // step, and then restore the breakpoint and continue. 
func (thread *ThreadContext) Continue() error { - regs, err := thread.Registers() + pc, err := thread.CurrentPC() if err != nil { return err } // Check whether we are stopped at a breakpoint, and // if so, single step over it before continuing. - if _, ok := thread.Process.BreakPoints[regs.PC()-1]; ok { + if _, ok := thread.Process.BreakPoints[pc-1]; ok { err := thread.Step() if err != nil { return fmt.Errorf("could not step %s", err) @@ -74,12 +72,12 @@ func (thread *ThreadContext) Continue() error { // Single steps this thread a single instruction, ensuring that // we correctly handle the likely case that we are at a breakpoint. func (thread *ThreadContext) Step() (err error) { - regs, err := thread.Registers() + pc, err := thread.CurrentPC() if err != nil { return err } - bp, ok := thread.Process.BreakPoints[regs.PC()-1] + bp, ok := thread.Process.BreakPoints[pc-1] if ok { // Clear the breakpoint so that we can continue execution. _, err = thread.Process.Clear(bp.Addr) @@ -88,14 +86,16 @@ func (thread *ThreadContext) Step() (err error) { } // Reset program counter to our restored instruction. - err = regs.SetPC(thread, bp.Addr) + err = thread.SetPC(bp.Addr) if err != nil { return fmt.Errorf("could not set registers %s", err) } // Restore breakpoint now that we have passed it. defer func() { - _, err = thread.Process.Break(bp.Addr) + var nbp *BreakPoint + nbp, err = thread.Process.Break(bp.Addr) + nbp.Temp = bp.Temp }() } @@ -107,96 +107,101 @@ func (thread *ThreadContext) Step() (err error) { return err } -// Step to next source line. Next will step over functions, -// and will follow through to the return address of a function. -// Next is implemented on the thread context, however during the -// course of this function running, it's very likely that the -// goroutine our M is executing will switch to another M, therefore -// this function cannot assume all execution will happen on this thread -// in the traced process. 
-func (thread *ThreadContext) Next() (err error) { - pc, err := thread.CurrentPC() +// Call a function named `name`. This is currently _NOT_ safe. +func (thread *ThreadContext) CallFn(name string, fn func(*ThreadContext) error) error { + f := thread.Process.GoSymTable.LookupFunc(name) + if f == nil { + return fmt.Errorf("could not find function %s", name) + } + + // Set breakpoint at the end of the function (before it returns). + bp, err := thread.Process.Break(f.End - 2) if err != nil { return err } + defer thread.Process.Clear(bp.Addr) - if bp, ok := thread.Process.BreakPoints[pc-1]; ok { - pc = bp.Addr - } - - fde, err := thread.Process.FrameEntries.FDEForPC(pc) - if err != nil { + if err := thread.saveRegisters(); err != nil { return err } - - _, l, _ := thread.Process.GoSymTable.PCToLine(pc) - ret := thread.ReturnAddressFromOffset(fde.ReturnAddressOffset(pc)) - for { - if err = thread.Step(); err != nil { - return err - } - - if pc, err = thread.CurrentPC(); err != nil { - return err - } - - if !fde.Cover(pc) && pc != ret { - if err := thread.continueToReturnAddress(pc, fde); err != nil { - if _, ok := err.(InvalidAddressError); !ok { - return err - } - } - if pc, err = thread.CurrentPC(); err != nil { - return err - } - } - - if _, nl, _ := thread.Process.GoSymTable.PCToLine(pc); nl != l { - break - } + if err = thread.SetPC(f.Entry); err != nil { + return err } - - return nil + defer thread.restoreRegisters() + if err := thread.Continue(); err != nil { + return err + } + if _, err = trapWait(thread.Process, -1); err != nil { + return err + } + return fn(thread) } -func (thread *ThreadContext) continueToReturnAddress(pc uint64, fde *frame.FrameDescriptionEntry) error { - for !fde.Cover(pc) { - // Offset is 0 because we have just stepped into this function. 
- addr := thread.ReturnAddressFromOffset(0) - bp, err := thread.Process.Break(addr) - if err != nil { - if _, ok := err.(BreakPointExistsError); !ok { - return err - } - } - bp.Temp = true - // Ensure we cleanup after ourselves no matter what. - defer thread.clearTempBreakpoint(bp.Addr) - - for { - err = thread.Continue() - if err != nil { - return err - } - // Wait on -1, just in case scheduler switches threads for this G. - wpid, err := trapWait(thread.Process, -1) - if err != nil { - return err - } - if wpid != thread.Id { - thread = thread.Process.Threads[wpid] - } - pc, err = thread.CurrentPC() - if err != nil { - return err - } - if (pc-1) == bp.Addr || pc == bp.Addr { - break - } - } +// Step to next source line. +// +// Next will step over functions, and will follow through to the +// return address of a function. +// +// This functionality is implemented by finding all possible next lines +// and setting a breakpoint at them. Once we've set a breakpoint at each +// potential line, we continue the thread. +func (thread *ThreadContext) Next() (err error) { + curpc, err := thread.CurrentPC() + if err != nil { + return err } - return nil + // Check and see if we're at a breakpoint, if so + // correct the PC value for the breakpoint instruction. + if bp, ok := thread.Process.BreakPoints[curpc-1]; ok { + curpc = bp.Addr + } + + // Grab info on our current stack frame. Used to determine + // whether we may be stepping outside of the current function. + fde, err := thread.Process.FrameEntries.FDEForPC(curpc) + if err != nil { + return err + } + + // Get current file/line. + f, l, _ := thread.Process.GoSymTable.PCToLine(curpc) + + // Find any line we could potentially get to. + lines, err := thread.Process.ast.NextLines(f, l) + if err != nil { + return err + } + + // Set a breakpoint at every line reachable from our location. 
+ for _, l := range lines { + pcs := thread.Process.LineInfo.AllPCsForFileLine(f, l) + for _, pc := range pcs { + if pc == curpc { + continue + } + if !fde.Cover(pc) { + pc = thread.ReturnAddressFromOffset(fde.ReturnAddressOffset(pc)) + } + bp, err := thread.Process.Break(pc) + if err != nil { + if err, ok := err.(BreakPointExistsError); !ok { + return err + } + continue + } + bp.Temp = true + } + } + return thread.Continue() +} + +func (thread *ThreadContext) SetPC(pc uint64) error { + regs, err := thread.Registers() + if err != nil { + return err + } + return regs.SetPC(thread, pc) } // Takes an offset from RSP and returns the address of the @@ -214,22 +219,33 @@ func (thread *ThreadContext) ReturnAddressFromOffset(offset int64) uint64 { } func (thread *ThreadContext) clearTempBreakpoint(pc uint64) error { - var software bool - if _, ok := thread.Process.BreakPoints[pc]; ok { - software = true + clearbp := func(bp *BreakPoint) error { + if _, err := thread.Process.Clear(bp.Addr); err != nil { + return err + } + return thread.SetPC(bp.Addr) } - if _, err := thread.Process.Clear(pc); err != nil { - return err + for _, bp := range thread.Process.HWBreakPoints { + if bp != nil && bp.Temp && bp.Addr == pc { + return clearbp(bp) + } } - if software { - // Reset program counter to our restored instruction. 
- regs, err := thread.Registers() + if bp, ok := thread.Process.BreakPoints[pc]; ok && bp.Temp { + return clearbp(bp) + } + return nil +} + +func (thread *ThreadContext) curG() (*G, error) { + var g *G + err := thread.CallFn("runtime.getg", func(t *ThreadContext) error { + regs, err := t.Registers() if err != nil { return err } - - return regs.SetPC(thread, pc) - } - - return nil + reader := t.Process.Dwarf.Reader() + g, err = parseG(t.Process, regs.SP()+uint64(ptrsize), reader) + return err + }) + return g, err } diff --git a/proctl/threads_darwin.c b/proctl/threads_darwin.c index 28afecda..4ae4bfe5 100644 --- a/proctl/threads_darwin.c +++ b/proctl/threads_darwin.c @@ -54,6 +54,12 @@ get_registers(mach_port_name_t task, x86_thread_state64_t *state) { return thread_get_state(task, x86_THREAD_STATE64, (thread_state_t)state, &stateCount); } +kern_return_t +set_registers(mach_port_name_t task, x86_thread_state64_t *state) { + mach_msg_type_number_t stateCount = x86_THREAD_STATE64_COUNT; + return thread_set_state(task, x86_THREAD_STATE64, (thread_state_t)state, stateCount); +} + kern_return_t set_pc(thread_act_t task, uint64_t pc) { kern_return_t kret; @@ -101,7 +107,6 @@ resume_thread(thread_act_t thread) { kret = thread_resume(thread); if (kret != KERN_SUCCESS) return kret; } - return KERN_SUCCESS; } diff --git a/proctl/threads_darwin.go b/proctl/threads_darwin.go index 77333399..f8aa2732 100644 --- a/proctl/threads_darwin.go +++ b/proctl/threads_darwin.go @@ -9,6 +9,7 @@ import ( type OSSpecificDetails struct { thread_act C.thread_act_t + registers C.x86_thread_state64_t } func (t *ThreadContext) Halt() error { @@ -81,3 +82,19 @@ func readMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) { } return len(data), nil } + +func (thread *ThreadContext) saveRegisters() error { + kret := C.get_registers(C.mach_port_name_t(thread.os.thread_act), &thread.os.registers) + if kret != C.KERN_SUCCESS { + return fmt.Errorf("could not save register contents") + } 
+ return nil +} + +func (thread *ThreadContext) restoreRegisters() error { + kret := C.set_registers(C.mach_port_name_t(thread.os.thread_act), &thread.os.registers) + if kret != C.KERN_SUCCESS { + return fmt.Errorf("could not save register contents") + } + return nil +} diff --git a/proctl/threads_darwin.h b/proctl/threads_darwin.h index e03d6ce7..eb4d6b44 100644 --- a/proctl/threads_darwin.h +++ b/proctl/threads_darwin.h @@ -23,3 +23,6 @@ clear_trap_flag(thread_act_t); kern_return_t resume_thread(thread_act_t); + +kern_return_t +set_registers(mach_port_name_t, x86_thread_state64_t*); diff --git a/proctl/threads_linux.go b/proctl/threads_linux.go index 29ea5852..4e0d045e 100644 --- a/proctl/threads_linux.go +++ b/proctl/threads_linux.go @@ -8,7 +8,9 @@ import ( // Not actually used, but necessary // to be defined. -type OSSpecificDetails interface{} +type OSSpecificDetails struct { + registers sys.PtraceRegs +} func (t *ThreadContext) Halt() error { if stopped(t.Id) { @@ -55,3 +57,17 @@ func writeMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) func readMemory(thread *ThreadContext, addr uintptr, data []byte) (int, error) { return sys.PtracePeekData(thread.Id, addr, data) } + +func (thread *ThreadContext) saveRegisters() error { + var regs sys.PtraceRegs + err := sys.PtraceGetRegs(thread.Id, &regs) + if err != nil { + return err + } + thread.os.registers = regs + return nil +} + +func (thread *ThreadContext) restoreRegisters() error { + return sys.PtraceSetRegs(thread.Id, &thread.os.registers) +} diff --git a/proctl/variables.go b/proctl/variables.go index a5b60216..a456de5f 100644 --- a/proctl/variables.go +++ b/proctl/variables.go @@ -31,6 +31,11 @@ type M struct { curg uintptr } +type G struct { + id int + pc uint64 +} + const ptrsize uintptr = unsafe.Sizeof(int(1)) // Parses and returns select info on the internal M @@ -217,49 +222,48 @@ func (dbp *DebuggedProcess) PrintGoroutinesInfo() error { allg := binary.LittleEndian.Uint64(faddr) for i 
:= uint64(0); i < allglen; i++ { - err = printGoroutineInfo(dbp, allg+(i*uint64(ptrsize)), reader) + g, err := parseG(dbp, allg+(i*uint64(ptrsize)), reader) if err != nil { return err } + f, l, fn := dbp.GoSymTable.PCToLine(g.pc) + fname := "" + if fn != nil { + fname = fn.Name + } + fmt.Printf("Goroutine %d - %s:%d %s\n", g.id, f, l, fname) } - return nil } -func printGoroutineInfo(dbp *DebuggedProcess, addr uint64, reader *dwarf.Reader) error { +func parseG(dbp *DebuggedProcess, addr uint64, reader *dwarf.Reader) (*G, error) { gaddrbytes, err := dbp.CurrentThread.readMemory(uintptr(addr), ptrsize) if err != nil { - return fmt.Errorf("error derefing *G %s", err) + return nil, fmt.Errorf("error derefing *G %s", err) } initialInstructions := append([]byte{op.DW_OP_addr}, gaddrbytes...) reader.Seek(0) goidaddr, err := offsetFor(dbp, "goid", reader, initialInstructions) if err != nil { - return err + return nil, err } reader.Seek(0) schedaddr, err := offsetFor(dbp, "sched", reader, initialInstructions) if err != nil { - return err + return nil, err } goidbytes, err := dbp.CurrentThread.readMemory(uintptr(goidaddr), ptrsize) if err != nil { - return fmt.Errorf("error reading goid %s", err) + return nil, fmt.Errorf("error reading goid %s", err) } schedbytes, err := dbp.CurrentThread.readMemory(uintptr(schedaddr+uint64(ptrsize)), ptrsize) if err != nil { - return fmt.Errorf("error reading sched %s", err) + return nil, fmt.Errorf("error reading sched %s", err) } gopc := binary.LittleEndian.Uint64(schedbytes) - f, l, fn := dbp.GoSymTable.PCToLine(gopc) - fname := "" - if fn != nil { - fname = fn.Name - } - fmt.Printf("Goroutine %d - %s:%d %s\n", binary.LittleEndian.Uint64(goidbytes), f, l, fname) - return nil + return &G{id: int(binary.LittleEndian.Uint64(goidbytes)), pc: gopc}, nil } func allglenval(dbp *DebuggedProcess, reader *dwarf.Reader) (uint64, error) { diff --git a/source/source.go b/source/source.go new file mode 100644 index 00000000..0f30de21 --- /dev/null +++ 
b/source/source.go @@ -0,0 +1,253 @@ +package source + +import ( + "go/ast" + "go/parser" + "go/token" +) + +type Searcher struct { + fileset *token.FileSet + visited map[string]*ast.File +} + +func New() *Searcher { + return &Searcher{fileset: token.NewFileSet(), visited: make(map[string]*ast.File)} +} + +// Returns the first node at the given file:line. +func (s *Searcher) FirstNodeAt(fname string, line int) (ast.Node, error) { + var node ast.Node + f, err := s.parse(fname) + if err != nil { + return nil, err + } + ast.Inspect(f, func(n ast.Node) bool { + if n == nil { + return true + } + position := s.fileset.Position(n.Pos()) + if position.Line == line { + node = n + return false + } + return true + }) + return node, nil +} + +type Done string + +func (d Done) Error() string { + return string(d) +} + +// Returns all possible lines that could be executed after the given file:line, +// within the same source file. +func (s *Searcher) NextLines(fname string, line int) (lines []int, err error) { + var found bool + n, err := s.FirstNodeAt(fname, line) + if err != nil { + return nil, err + } + defer func() { + if e := recover(); e != nil { + e = e.(Done) + nl := make([]int, 0, len(lines)) + fnd := make(map[int]bool) + for _, l := range lines { + if _, ok := fnd[l]; !ok { + fnd[l] = true + nl = append(nl, l) + } + } + lines = nl + } + }() + + switch x := n.(type) { + // Check if we are at an 'if' statement. + // + // If we are at an 'if' statement, employ the following algorithm: + // * Follow all 'else if' statements, appending their line number + // * Follow any 'else' statement if it exists, appending the line + // number of the statement following the 'else'. + // * If there is no 'else' statement, append line of first statement + // following the entire 'if' block. 
+ case *ast.IfStmt: + var rbrace int + p := x.Body.List[0].Pos() + pos := s.fileset.Position(p) + lines = append(lines, pos.Line) + + if x.Else == nil { + // Grab first line after entire 'if' block + rbrace = s.fileset.Position(x.Body.Rbrace).Line + n, err := s.FirstNodeAt(fname, 1) + if err != nil { + return nil, err + } + ast.Inspect(n, func(n ast.Node) bool { + if n == nil { + return true + } + pos := s.fileset.Position(n.Pos()) + if rbrace < pos.Line { + lines = append(lines, pos.Line) + panic(Done("done")) + } + return true + }) + } else { + // Follow any 'else' statements + for { + if stmt, ok := x.Else.(*ast.IfStmt); ok { + pos := s.fileset.Position(stmt.Pos()) + lines = append(lines, pos.Line) + x = stmt + continue + } + pos := s.fileset.Position(x.Else.Pos()) + ast.Inspect(x, func(n ast.Node) bool { + if found { + panic(Done("done")) + } + if n == nil { + return false + } + p := s.fileset.Position(n.Pos()) + if pos.Line < p.Line { + lines = append(lines, p.Line) + found = true + return false + } + return true + }) + } + } + + // Follow case statements. + // + // Append line for first statement following each 'case' condition. + case *ast.SwitchStmt: + ast.Inspect(x, func(n ast.Node) bool { + if stmt, ok := n.(*ast.SwitchStmt); ok { + ast.Inspect(stmt, func(n ast.Node) bool { + if stmt, ok := n.(*ast.CaseClause); ok { + p := stmt.Body[0].Pos() + pos := s.fileset.Position(p) + lines = append(lines, pos.Line) + return false + } + return true + }) + panic(Done("done")) + } + return true + }) + // Default case - find next source line. + // + // We are not at a branch, employ the following algorithm: + // * Traverse tree, storing any loop as a parent + // * Find next source line after the given line + // * Check and see if we've passed the scope of any parent we've + // stored. If so, pop them off the stack. The last parent that + // is left gets appended to our list of lines since we could + // end up at the top of the loop again. 
+ default: + var ( + parents []*ast.BlockStmt + parentLines []int + parentLine int + ) + f, err := s.parse(fname) + if err != nil { + return nil, err + } + ast.Inspect(f, func(n ast.Node) bool { + if found { + panic(Done("done")) + } + if n == nil { + return true + } + if stmt, ok := n.(*ast.ForStmt); ok { + parents = append(parents, stmt.Body) + pos := s.fileset.Position(stmt.Pos()) + parentLine = pos.Line + parentLines = append(parentLines, pos.Line) + } + pos := s.fileset.Position(n.Pos()) + if line < pos.Line { + if _, ok := n.(*ast.BlockStmt); ok { + return true + } + for { + if 0 < len(parents) { + parent := parents[len(parents)-1] + endLine := s.fileset.Position(parent.Rbrace).Line + if endLine < line { + if len(parents) == 1 { + parents = []*ast.BlockStmt{} + parentLines = []int{} + parentLine = 0 + } else { + parents = parents[0 : len(parents)-1] + parentLines = parentLines[0:len(parents)] + parent = parents[len(parents)-1] + parentLine = s.fileset.Position(parent.Pos()).Line + } + continue + } + if parentLine != 0 { + var endfound bool + ast.Inspect(f, func(n ast.Node) bool { + if n == nil || endfound { + return false + } + if _, ok := n.(*ast.BlockStmt); ok { + return true + } + pos := s.fileset.Position(n.Pos()) + if endLine < pos.Line { + endLine = pos.Line + endfound = true + return false + } + return true + }) + lines = append(lines, parentLine, endLine) + } + } + break + } + if _, ok := n.(*ast.BranchStmt); !ok { + lines = append(lines, pos.Line) + } + found = true + return false + } + return true + }) + if len(lines) == 0 && 0 < len(parents) { + parent := parents[len(parents)-1] + lbrace := s.fileset.Position(parent.Lbrace).Line + pos := s.fileset.Position(parent.List[0].Pos()) + lines = append(lines, lbrace, pos.Line) + } + } + return lines, nil +} + +// Parses file named by fname, caching files it has already parsed. 
+func (s *Searcher) parse(fname string) (*ast.File, error) { + if f, ok := s.visited[fname]; ok { + return f, nil + } + f, err := parser.ParseFile(s.fileset, fname, nil, 0) + if err != nil { + return nil, err + } + s.visited[fname] = f + return f, nil +} diff --git a/source/source_test.go b/source/source_test.go new file mode 100644 index 00000000..aceb70ec --- /dev/null +++ b/source/source_test.go @@ -0,0 +1,57 @@ +package source + +import ( + "fmt" + "go/ast" + "path/filepath" + "testing" +) + +func TestTokenAtLine(t *testing.T) { + var ( + tf, _ = filepath.Abs("../_fixtures/testvisitorprog.go") + v = New() + ) + n, err := v.FirstNodeAt(tf, 8) + if err != nil { + t.Fatal(err) + } + if _, ok := n.(*ast.IfStmt); !ok { + t.Fatal("Did not get correct node") + } +} + +func TestNextLines(t *testing.T) { + var ( + tf, _ = filepath.Abs("../_fixtures/testvisitorprog.go") + v = New() + ) + cases := []struct { + line int + nextlines []int + }{ + {8, []int{9, 10, 13}}, + {15, []int{17, 19}}, + {25, []int{27}}, + {22, []int{6, 25}}, + {33, []int{36}}, + {36, []int{37, 40}}, + {47, []int{44, 51}}, + {57, []int{55, 56}}, + } + for i, c := range cases { + lines, err := v.NextLines(tf, c.line) + if err != nil { + t.Fatal(err) + } + if len(lines) != len(c.nextlines) { + fmt.Println(lines) + t.Fatalf("did not get correct number of lines back expected %d got %d for test case %d", len(c.nextlines), len(lines), i+1) + } + for i, l := range lines { + if l != c.nextlines[i] { + t.Fatalf("expected index %d to be %d got %d", i, c.nextlines[i], l) + } + } + } +}