diff --git a/main.go b/main.go index 76f4381..150481c 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "net" "os" "strings" "time" @@ -95,7 +96,14 @@ func main() { *dataOrigin = configData.DataOrigin } - ip := util.DomainLookUp(domain) + var ip net.IP + + if *tcpSYNFlag || *udpPackageFlag { + ip = util.DomainLookUp(domain, true) + } else { + ip = util.DomainLookUp(domain, false) + } + printer.PrintTraceRouteNav(ip, domain, *dataOrigin) var m trace.Method = "" diff --git a/printer/printer_test.go b/printer/printer_test.go index d285211..25460f4 100644 --- a/printer/printer_test.go +++ b/printer/printer_test.go @@ -2,16 +2,17 @@ package printer import ( "errors" - "github.com/xgadget-lab/nexttrace/ipgeo" - "github.com/xgadget-lab/nexttrace/trace" - "github.com/xgadget-lab/nexttrace/util" "net" "testing" "time" + + "github.com/xgadget-lab/nexttrace/ipgeo" + "github.com/xgadget-lab/nexttrace/trace" + "github.com/xgadget-lab/nexttrace/util" ) func TestPrintTraceRouteNav(t *testing.T) { - PrintTraceRouteNav(util.DomainLookUp("1.1.1.1"), "1.1.1.1", "dataOrigin") + PrintTraceRouteNav(util.DomainLookUp("1.1.1.1", false), "1.1.1.1", "dataOrigin") } var testGeo = &ipgeo.IPGeoData{ diff --git a/util/util.go b/util/util.go index e1f2bcf..b981b38 100644 --- a/util/util.go +++ b/util/util.go @@ -25,7 +25,7 @@ func LocalIPPort(dstip net.IP) (net.IP, int) { return nil, -1 } -func DomainLookUp(host string) net.IP { +func DomainLookUp(host string, ipv4Only bool) net.IP { ips, err := net.LookupIP(host) if err != nil { fmt.Println("Domain " + host + " Lookup Fail.") @@ -36,17 +36,20 @@ func DomainLookUp(host string) net.IP { var ipv6Flag = false for _, ip := range ips { - ipSlice = append(ipSlice, ip) - // 仅返回ipv4的ip - // if ip.To4() != nil { - // ipSlice = append(ipSlice, ip) - // } else { - // ipv6Flag = true - // } + if ipv4Only { + // 仅返回ipv4的ip + if ip.To4() != nil { + ipSlice = append(ipSlice, ip) + } else { + ipv6Flag = true + } + } else { + ipSlice = append(ipSlice, ip) + } } if ipv6Flag { - fmt.Println("[Info] IPv6 Traceroute is not supported right now.") + fmt.Println("[Info] IPv6 TCP/UDP Traceroute is not supported right now.") if len(ipSlice) == 0 { os.Exit(0) }