Add Traceroute Module

This commit is contained in:
sjlleo
2022-04-24 10:07:01 +08:00
parent 6f83beb899
commit 1a62f39a6e
12 changed files with 1198 additions and 0 deletions

10
go.mod Normal file
View File

@@ -0,0 +1,10 @@
module traceroute
go 1.18
require (
github.com/google/gopacket v1.1.19
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4
)
require golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005 // indirect

22
go.sum Normal file
View File

@@ -0,0 +1,22 @@
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4 h1:b0LrWgu8+q7z4J+0Y3Umo5q1dL7NXBkKBWkaVkAq17E=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005 h1:pDMpM2zh2MT0kHy037cKlSby2nEhD50SYqwQk76Nm40=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -0,0 +1,61 @@
package listener_channel
import (
"golang.org/x/net/context"
"net"
"time"
)
type ReceivedMessage struct {
N *int
Peer net.Addr
Msg []byte
Err error
}
type ListenerChannel struct {
ctx context.Context
cancel context.CancelFunc
Conn net.PacketConn
Messages chan ReceivedMessage
}
func New(conn net.PacketConn) *ListenerChannel {
ctx, cancel := context.WithCancel(context.Background())
results := make(chan ReceivedMessage, 50)
return &ListenerChannel{Conn: conn, ctx: ctx, cancel: cancel, Messages: results}
}
func (l *ListenerChannel) Start() {
for {
select {
case <-l.ctx.Done():
return
default:
}
reply := make([]byte, 1500)
err := l.Conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if err != nil {
l.Messages <- ReceivedMessage{Err: err}
continue
}
n, peer, err := l.Conn.ReadFrom(reply)
if err != nil {
l.Messages <- ReceivedMessage{Err: err}
continue
}
l.Messages <- ReceivedMessage{
N: &n,
Peer: peer,
Err: nil,
Msg: reply,
}
}
}
func (l *ListenerChannel) Stop() {
l.cancel()
}

211
main.go Normal file
View File

@@ -0,0 +1,211 @@
package main
import (
"traceroute/methods"
"traceroute/methods/tcp"
"traceroute/methods/udp"
"os"
"net"
"time"
"fmt"
"net/http"
"io/ioutil"
"encoding/json"
"flag"
"strings"
)
type IPGeoData struct {
Asnumber string `json:"asnumber"`
Country string `json:"country"`
Prov string `json:"prov"`
City string `json:"city"`
District string `json:"district"`
Owner string `json:"owner"`
Isp string `json:"isp"`
}
var tcpSYNFlag = flag.Bool("T", false, "Use TCP SYN for tracerouting (default port is 80 in TCP, 53 in UDP)")
var port = flag.Int("p", 80, "Set SYN Traceroute Port")
var numMeasurements = flag.Int("q", 3, "Set the number of probes per each hop.")
var parallelRequests = flag.Int("r", 18, "Set ParallelRequests number. It should be 1 when there is a multi-routing.")
var maxHops = flag.Int("m", 30, "Set the max number of hops (max TTL to be reached).")
func main() {
fmt.Println("ManGoTrace v0.0.1 Alpha \nOwO Organiztion Leo (leo.moe) & Vincent (vincent.moe)")
ip := domainLookUp(flagApply())
fmt.Printf("traceroute to %s, 30 hops max, 32 byte packets\n", ip.String())
if (*tcpSYNFlag) {
tcpTraceroute := tcp.New(ip, methods.TracerouteConfig{
MaxHops: uint16(*maxHops),
NumMeasurements: uint16(*numMeasurements),
ParallelRequests: uint16(*parallelRequests),
Port: *port,
Timeout: time.Second / 2,
})
res, _ := tcpTraceroute.Start()
traceroutePrinter(ip, res)
} else {
if (*port == 80) {
*port = 53
}
udpTraceroute := udp.New(ip, true, methods.TracerouteConfig{
MaxHops: uint16(*maxHops),
NumMeasurements: uint16(*numMeasurements),
ParallelRequests: uint16(*parallelRequests),
Port: *port,
Timeout: 2 * time.Second,
})
res, _ := udpTraceroute.Start()
traceroutePrinter(ip, res)
}
}
func traceroutePrinter(ip net.IP, res *map[uint16][]methods.TracerouteHop) {
hopIndex := uint16(1)
for ; hopIndex <= 29 ; {
for k,v := range *res {
if (k == hopIndex) {
fmt.Print(k)
for _,v2 := range v {
ch := make(chan uint16)
go hopPrinter(hopIndex, ip, v2, ch)
hopIndex = <- ch
}
hopIndex = hopIndex + 1
break
}
}
}
}
func flagApply() string{
flag.Parse()
ipArg := flag.Args()
if (flag.NArg() != 1) {
fmt.Println("Args Error")
os.Exit(2)
}
return ipArg[0]
}
func getIPGeo(ip string, c chan IPGeoData) {
resp, err := http.Get("https://leo.moe/api.php?ip=" + ip)
if err != nil {
fmt.Println(err)
}
defer resp.Body.Close()
body, _ := ioutil.ReadAll(resp.Body)
ipGeoData := IPGeoData{}
err = json.Unmarshal(body,&ipGeoData)
if err != nil {
fmt.Println(err)
}
c <- ipGeoData
}
func domainLookUp(host string) net.IP {
ips, err := net.LookupIP(host)
if (err != nil) {
fmt.Println("Domain Lookup Fail.")
os.Exit(1)
}
var ipSlice = []net.IP{}
for _, ip := range ips {
ipSlice = append(ipSlice, ip)
}
if (len(ipSlice) == 1) {
return ipSlice[0]
} else {
fmt.Println("Please Choose the IP You Want To TraceRoute")
for i, ip := range ipSlice {
fmt.Printf("%d. %s\n",i, ip)
}
var index int
fmt.Printf("Your Option: ")
fmt.Scanln(&index)
if (index >= len(ipSlice) || index < 0) {
fmt.Println("Your Option is invalid")
os.Exit(3)
}
return ipSlice[index]
}
}
func hopPrinter(hopIndex uint16, ip net.IP, v2 methods.TracerouteHop, c chan uint16) {
if (v2.Address == nil) {
fmt.Println("\t*")
} else {
ip_str := fmt.Sprintf("%s", v2.Address)
ptr, err := net.LookupAddr(ip_str)
ch_b := make(chan IPGeoData)
go getIPGeo(ip_str, ch_b)
iPGeoData := <-ch_b
if (ip.String() == ip_str) {
hopIndex = 30
iPGeoData.Owner = iPGeoData.Isp
}
if (strings.Index(ip_str, "9.31.") == 0 || strings.Index(ip_str, "11.72.") == 0) {
fmt.Printf("\t%-15s %.2fms * 局域网, 腾讯云\n", v2.Address, v2.RTT.Seconds()*1000)
c <- hopIndex
return
}
if (strings.Index(ip_str, "11.13.") == 0) {
fmt.Printf("\t%-15s %.2fms * 局域网, 阿里云\n", v2.Address, v2.RTT.Seconds()*1000)
c <- hopIndex
return
}
if (iPGeoData.Owner == "") {
iPGeoData.Owner = iPGeoData.Isp
}
if (iPGeoData.Asnumber == "") {
iPGeoData.Asnumber = "*"
} else {
iPGeoData.Asnumber = "AS" + iPGeoData.Asnumber
}
if (iPGeoData.District != "") {
iPGeoData.City = iPGeoData.City + ", " + iPGeoData.District
}
if (iPGeoData.Country == "") {
fmt.Printf("\t%-15s %.2fms * 局域网\n", v2.Address, v2.RTT.Seconds()*1000)
c <- hopIndex
return
}
if (iPGeoData.Prov == "" && iPGeoData.City == "") {
if err != nil {
fmt.Printf("\t%-15s %.2fms %s %s, %s, %s 骨干网\n",v2.Address, v2.RTT.Seconds()*1000, iPGeoData.Asnumber, iPGeoData.Country, iPGeoData.Owner, iPGeoData.Owner)
} else {
fmt.Printf("\t%-15s (%s) %.2fms %s %s, %s, %s 骨干网\n",ptr[0], v2.Address, v2.RTT.Seconds()*1000, iPGeoData.Asnumber, iPGeoData.Country, iPGeoData.Owner, iPGeoData.Owner)
}
} else {
if err != nil {
fmt.Printf("\t%-15s %.2fms %s %s, %s, %s, %s\n",v2.Address, v2.RTT.Seconds()*1000, iPGeoData.Asnumber, iPGeoData.Country, iPGeoData.Prov, iPGeoData.City, iPGeoData.Owner)
} else {
fmt.Printf("\t%-15s (%s) %.2fms %s %s, %s, %s, %s\n",ptr[0], v2.Address, v2.RTT.Seconds()*1000, iPGeoData.Asnumber, iPGeoData.Country, iPGeoData.Prov, iPGeoData.City, iPGeoData.Owner)
}
}
}
c <- hopIndex
}

79
methods/methods.go Normal file
View File

@@ -0,0 +1,79 @@
package methods
import (
"encoding/binary"
"errors"
"net"
"time"
)
// TracerouteHop type
type TracerouteHop struct {
Success bool
Address net.Addr
TTL uint16
RTT *time.Duration
}
type TracerouteConfig struct {
MaxHops uint16
NumMeasurements uint16
ParallelRequests uint16
Port int
Timeout time.Duration
}
func GetIPHeaderLength(data []byte) (int, error) {
if len(data) < 1 {
return 0, errors.New("received invalid IP header")
}
return int((data[0] & 0x0F) * 4), nil
}
func GetICMPResponsePayload(data []byte) ([]byte, error) {
length, err := GetIPHeaderLength(data)
if err != nil {
return nil, err
}
if len(data) < length {
return nil, errors.New("length of packet too short")
}
return data[length:], nil
}
func GetUDPSrcPort(data []byte) uint16 {
srcPortBytes := data[:2]
srcPort := binary.BigEndian.Uint16(srcPortBytes)
return srcPort
}
func GetTCPSeq(data []byte) uint32 {
seqBytes := data[4:8]
return binary.BigEndian.Uint32(seqBytes)
}
func ReduceFinalResult(preliminary map[uint16][]TracerouteHop, maxHops uint16, destIP net.IP) map[uint16][]TracerouteHop {
// reduce the results to remove all hops after the first encounter to final destination
finalResults := map[uint16][]TracerouteHop{}
for i := uint16(1); i < maxHops; i++ {
foundFinal := false
probes := preliminary[i]
if probes == nil {
break
}
finalResults[i] = []TracerouteHop{}
for _, probe := range probes {
if probe.Success && probe.Address.String() == destIP.String() {
foundFinal = true
}
finalResults[i] = append(finalResults[i], probe)
}
if foundFinal {
break
}
}
return finalResults
}

37
methods/quic/quic.go Normal file

File diff suppressed because one or more lines are too long

328
methods/tcp/tcp.go Normal file
View File

@@ -0,0 +1,328 @@
package tcp
import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"traceroute/listener_channel"
"traceroute/methods"
"traceroute/parallel_limiter"
"traceroute/signal"
"traceroute/util"
"golang.org/x/net/context"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"log"
"math"
"math/rand"
"net"
"sync"
"time"
)
type inflightData struct {
start time.Time
ttl uint16
}
type results struct {
inflightRequests sync.Map
results map[uint16][]methods.TracerouteHop
resultsMu sync.Mutex
err error
concurrentRequests *parallel_limiter.ParallelLimiter
reachedFinalHop *signal.Signal
}
type Traceroute struct {
opConfig opConfig
trcrtConfig methods.TracerouteConfig
results results
}
type opConfig struct {
icmpConn net.PacketConn
tcpConn net.PacketConn
tcpMu sync.Mutex
destIP net.IP
srcIP net.IP
wg *sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func New(destIP net.IP, config methods.TracerouteConfig) *Traceroute {
return &Traceroute{
opConfig: opConfig{
destIP: destIP,
},
trcrtConfig: config,
}
}
func (tr *Traceroute) Start() (*map[uint16][]methods.TracerouteHop, error) {
tr.opConfig.ctx, tr.opConfig.cancel = context.WithCancel(context.Background())
tr.opConfig.srcIP, _ = util.LocalIPPort(tr.opConfig.destIP)
var err error
tr.opConfig.tcpConn, err = net.ListenPacket("ip4:tcp", tr.opConfig.srcIP.String())
if err != nil {
return nil, err
}
tr.opConfig.icmpConn, err = icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return nil, err
}
var wg sync.WaitGroup
tr.opConfig.wg = &wg
tr.results = results{
inflightRequests: sync.Map{},
concurrentRequests: parallel_limiter.New(int(tr.trcrtConfig.ParallelRequests)),
reachedFinalHop: signal.New(),
results: map[uint16][]methods.TracerouteHop{},
}
return tr.start()
}
func (tr *Traceroute) timeoutLoop() {
ticker := time.NewTicker(tr.trcrtConfig.Timeout / 4)
go func() {
for range ticker.C {
tr.results.inflightRequests.Range(func(key, value interface{}) bool {
request := value.(inflightData)
expired := time.Since(request.start) > tr.trcrtConfig.Timeout
if !expired {
return true
}
tr.results.inflightRequests.Delete(key)
tr.addToResult(request.ttl, methods.TracerouteHop{
Success: false,
TTL: request.ttl,
})
tr.results.concurrentRequests.Finished()
tr.opConfig.wg.Done()
return true
})
}
}()
select {
case <-tr.opConfig.ctx.Done():
ticker.Stop()
}
}
func (tr *Traceroute) addToResult(ttl uint16, hop methods.TracerouteHop) {
tr.results.resultsMu.Lock()
defer tr.results.resultsMu.Unlock()
if tr.results.results[ttl] == nil {
tr.results.results[ttl] = []methods.TracerouteHop{}
}
tr.results.results[ttl] = append(tr.results.results[ttl], hop)
}
func (tr *Traceroute) handleICMPMessage(msg listener_channel.ReceivedMessage, data []byte) {
header, err := methods.GetICMPResponsePayload(data)
if err != nil {
return
}
sequenceNumber := methods.GetTCPSeq(header)
val, ok := tr.results.inflightRequests.LoadAndDelete(sequenceNumber)
if !ok {
return
}
request := val.(inflightData)
elapsed := time.Since(request.start)
if msg.Peer.String() == tr.opConfig.destIP.String() {
tr.results.reachedFinalHop.Signal()
}
tr.addToResult(request.ttl, methods.TracerouteHop{
Success: true,
Address: msg.Peer,
TTL: request.ttl,
RTT: &elapsed,
})
tr.results.concurrentRequests.Finished()
tr.opConfig.wg.Done()
}
func (tr *Traceroute) icmpListener() {
lc := listener_channel.New(tr.opConfig.icmpConn)
defer lc.Stop()
go lc.Start()
for {
select {
case <-tr.opConfig.ctx.Done():
return
case msg := <-lc.Messages:
if msg.N == nil {
continue
}
rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
if err != nil {
log.Println(err)
continue
}
switch rm.Type {
case ipv4.ICMPTypeTimeExceeded:
body := rm.Body.(*icmp.TimeExceeded).Data
tr.handleICMPMessage(msg, body)
case ipv4.ICMPTypeDestinationUnreachable:
body := rm.Body.(*icmp.DstUnreach).Data
tr.handleICMPMessage(msg, body)
default:
log.Println("received icmp message of unknown type")
}
}
}
}
func (tr *Traceroute) tcpListener() {
lc := listener_channel.New(tr.opConfig.tcpConn)
defer lc.Stop()
go lc.Start()
for {
select {
case <-tr.opConfig.ctx.Done():
return
case msg := <-lc.Messages:
if msg.N == nil {
continue
}
if msg.Peer.String() != tr.opConfig.destIP.String() {
continue
}
// Decode a packet
packet := gopacket.NewPacket(msg.Msg[:*msg.N], layers.LayerTypeTCP, gopacket.Default)
// Get the TCP layer from this packet
if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
tcp, _ := tcpLayer.(*layers.TCP)
val, ok := tr.results.inflightRequests.LoadAndDelete(tcp.Ack - 1)
if !ok {
continue
}
request := val.(inflightData)
tr.results.concurrentRequests.Finished()
elapsed := time.Since(request.start)
if msg.Peer.String() == tr.opConfig.destIP.String() {
tr.results.reachedFinalHop.Signal()
}
tr.addToResult(request.ttl, methods.TracerouteHop{
Success: true,
Address: msg.Peer,
TTL: request.ttl,
RTT: &elapsed,
})
tr.opConfig.wg.Done()
}
}
}
}
func (tr *Traceroute) sendMessage(ttl uint16) {
_, srcPort := util.LocalIPPort(tr.opConfig.destIP)
ipHeader := &layers.IPv4{
SrcIP: tr.opConfig.srcIP,
DstIP: tr.opConfig.destIP,
Protocol: layers.IPProtocolTCP,
TTL: uint8(ttl),
}
sequenceNumber := uint32(rand.Intn(math.MaxUint32))
tcpHeader := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(tr.trcrtConfig.Port),
Seq: sequenceNumber,
SYN: true,
Window: 14600,
}
_ = tcpHeader.SetNetworkLayerForChecksum(ipHeader)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err := gopacket.SerializeLayers(buf, opts, tcpHeader); err != nil {
tr.results.err = err
tr.opConfig.cancel()
return
}
tr.opConfig.tcpMu.Lock()
defer tr.opConfig.tcpMu.Unlock()
err := ipv4.NewPacketConn(tr.opConfig.tcpConn).SetTTL(int(ttl))
if err != nil {
tr.results.err = err
tr.opConfig.cancel()
return
}
start := time.Now()
if _, err := tr.opConfig.tcpConn.WriteTo(buf.Bytes(), &net.IPAddr{IP: tr.opConfig.destIP}); err != nil {
tr.results.err = err
tr.opConfig.cancel()
return
}
tr.results.inflightRequests.Store(sequenceNumber, inflightData{start: start, ttl: ttl})
}
func (tr *Traceroute) sendLoop() {
rand.Seed(time.Now().UTC().UnixNano())
defer tr.opConfig.wg.Done()
for ttl := uint16(1); ttl <= tr.trcrtConfig.MaxHops; ttl++ {
select {
case <-tr.results.reachedFinalHop.Chan():
return
default:
}
for i := 0; i < int(tr.trcrtConfig.NumMeasurements); i++ {
select {
case <-tr.opConfig.ctx.Done():
return
case <-tr.results.concurrentRequests.Start():
tr.opConfig.wg.Add(1)
go tr.sendMessage(ttl)
}
}
}
}
func (tr *Traceroute) start() (*map[uint16][]methods.TracerouteHop, error) {
go tr.timeoutLoop()
go tr.icmpListener()
go tr.tcpListener()
tr.opConfig.wg.Add(1)
go tr.sendLoop()
tr.opConfig.wg.Wait()
tr.opConfig.cancel()
if tr.results.err != nil {
return nil, tr.results.err
}
result := methods.ReduceFinalResult(tr.results.results, tr.trcrtConfig.MaxHops, tr.opConfig.destIP)
return &result, tr.results.err
}

310
methods/udp/udp.go Normal file
View File

@@ -0,0 +1,310 @@
package udp
import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"traceroute/listener_channel"
"traceroute/methods"
"traceroute/methods/quic"
"traceroute/parallel_limiter"
"traceroute/signal"
"traceroute/taskgroup"
"traceroute/util"
"golang.org/x/net/context"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"log"
"math/rand"
"net"
"sync"
time "time"
)
type inflightData struct {
icmpMsg chan<- net.Addr
}
type opConfig struct {
quic bool
destIP net.IP
wg *taskgroup.TaskGroup
icmpConn net.PacketConn
ctx context.Context
cancel context.CancelFunc
}
type results struct {
inflightRequests sync.Map
results map[uint16][]methods.TracerouteHop
resultsMu sync.Mutex
err error
concurrentRequests *parallel_limiter.ParallelLimiter
reachedFinalHop *signal.Signal
}
type Traceroute struct {
trcrtConfig methods.TracerouteConfig
opConfig opConfig
results results
}
func New(destIP net.IP, quic bool, config methods.TracerouteConfig) *Traceroute {
return &Traceroute{
opConfig: opConfig{
quic: quic,
destIP: destIP,
},
trcrtConfig: config,
}
}
func (tr *Traceroute) Start() (*map[uint16][]methods.TracerouteHop, error) {
tr.opConfig.ctx, tr.opConfig.cancel = context.WithCancel(context.Background())
tr.results = results{
inflightRequests: sync.Map{},
concurrentRequests: parallel_limiter.New(int(tr.trcrtConfig.ParallelRequests)),
results: map[uint16][]methods.TracerouteHop{},
reachedFinalHop: signal.New(),
}
var err error
tr.opConfig.icmpConn, err = icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return nil, err
}
return tr.start()
}
func (tr *Traceroute) addToResult(ttl uint16, hop methods.TracerouteHop) {
tr.results.resultsMu.Lock()
defer tr.results.resultsMu.Unlock()
if tr.results.results[ttl] == nil {
tr.results.results[ttl] = []methods.TracerouteHop{}
}
tr.results.results[ttl] = append(tr.results.results[ttl], hop)
}
func (tr *Traceroute) getUDPConn(try int) (net.IP, int, net.PacketConn) {
srcIP, _ := util.LocalIPPort(tr.opConfig.destIP)
var ipString string
if srcIP == nil {
ipString = ""
} else {
ipString = srcIP.String()
}
udpConn, err := net.ListenPacket("udp", ipString+":0")
if err != nil {
if try > 3 {
log.Fatal(err)
}
return tr.getUDPConn(try + 1)
}
return srcIP, udpConn.LocalAddr().(*net.UDPAddr).Port, udpConn
}
func (tr *Traceroute) sendMessage(ttl uint16) {
srcIP, srcPort, udpConn := tr.getUDPConn(0)
var payload []byte
if tr.opConfig.quic {
payload = quic.GenerateWithRandomIds()
} else {
ipHeader := &layers.IPv4{
SrcIP: srcIP,
DstIP: tr.opConfig.destIP,
Protocol: layers.IPProtocolTCP,
TTL: uint8(ttl),
}
udpHeader := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(tr.trcrtConfig.Port),
}
_ = udpHeader.SetNetworkLayerForChecksum(ipHeader)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err := gopacket.SerializeLayers(buf, opts, udpHeader, gopacket.Payload("HAJSFJHKAJSHFKJHAJKFHKASHKFHHKAFKHFAHSJK")); err != nil {
tr.results.err = err
tr.opConfig.cancel()
return
}
payload = buf.Bytes()
}
err := ipv4.NewPacketConn(udpConn).SetTTL(int(ttl))
if err != nil {
tr.results.err = err
tr.opConfig.cancel()
return
}
icmpMsg := make(chan net.Addr, 1)
udpMsg := make(chan net.Addr, 1)
start := time.Now()
if _, err := udpConn.WriteTo(payload, &net.UDPAddr{IP: tr.opConfig.destIP, Port: tr.trcrtConfig.Port}); err != nil {
tr.results.err = err
tr.opConfig.cancel()
return
}
inflight := inflightData{
icmpMsg: icmpMsg,
}
tr.results.inflightRequests.Store(uint16(srcPort), inflight)
go func() {
reply := make([]byte, 1500)
_, peer, err := udpConn.ReadFrom(reply)
if err != nil {
// probably because we closed the connection
return
}
udpMsg <- peer
}()
select {
case peer := <-icmpMsg:
rtt := time.Since(start)
if peer.(*net.IPAddr).IP.Equal(tr.opConfig.destIP) {
tr.results.reachedFinalHop.Signal()
}
tr.addToResult(ttl, methods.TracerouteHop{
Success: true,
Address: peer,
TTL: ttl,
RTT: &rtt,
})
case peer := <-udpMsg:
rtt := time.Since(start)
ip := peer.(*net.UDPAddr).IP
if ip.Equal(tr.opConfig.destIP) {
tr.results.reachedFinalHop.Signal()
}
tr.addToResult(ttl, methods.TracerouteHop{
Success: true,
Address: &net.IPAddr{IP: ip},
TTL: ttl,
RTT: &rtt,
})
case <-time.After(tr.trcrtConfig.Timeout):
tr.addToResult(ttl, methods.TracerouteHop{
Success: false,
Address: nil,
TTL: ttl,
RTT: nil,
})
}
tr.results.inflightRequests.Delete(uint16(srcPort))
udpConn.Close()
tr.results.concurrentRequests.Finished()
tr.opConfig.wg.Done()
}
func (tr *Traceroute) handleICMPMessage(msg listener_channel.ReceivedMessage, data []byte) {
header, err := methods.GetICMPResponsePayload(data)
if err != nil {
return
}
srcPort := methods.GetUDPSrcPort(header)
val, ok := tr.results.inflightRequests.LoadAndDelete(srcPort)
if !ok {
return
}
request := val.(inflightData)
request.icmpMsg <- msg.Peer
}
func (tr *Traceroute) icmpListener() {
lc := listener_channel.New(tr.opConfig.icmpConn)
defer lc.Stop()
go lc.Start()
for {
select {
case <-tr.opConfig.ctx.Done():
return
case msg := <-lc.Messages:
if msg.N == nil {
continue
}
rm, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
if err != nil {
log.Println(err)
continue
}
switch rm.Type {
case ipv4.ICMPTypeTimeExceeded:
body := rm.Body.(*icmp.TimeExceeded).Data
tr.handleICMPMessage(msg, body)
case ipv4.ICMPTypeDestinationUnreachable:
body := rm.Body.(*icmp.DstUnreach).Data
tr.handleICMPMessage(msg, body)
default:
log.Println("received icmp message of unknown type", rm.Type)
}
}
}
}
func (tr *Traceroute) sendLoop() {
rand.Seed(time.Now().UTC().UnixNano())
for ttl := uint16(1); ttl <= tr.trcrtConfig.MaxHops; ttl++ {
select {
case <-tr.results.reachedFinalHop.Chan():
return
default:
}
for i := 0; i < int(tr.trcrtConfig.NumMeasurements); i++ {
select {
case <-tr.opConfig.ctx.Done():
return
case <-tr.results.concurrentRequests.Start():
tr.opConfig.wg.Add()
go tr.sendMessage(ttl)
}
}
}
}
func (tr *Traceroute) start() (*map[uint16][]methods.TracerouteHop, error) {
go tr.icmpListener()
wg := taskgroup.New()
tr.opConfig.wg = wg
tr.sendLoop()
wg.Wait()
tr.opConfig.cancel()
tr.opConfig.icmpConn.Close()
if tr.results.err != nil {
return nil, tr.results.err
}
result := methods.ReduceFinalResult(tr.results.results, tr.trcrtConfig.MaxHops, tr.opConfig.destIP)
return &result, tr.results.err
}

View File

@@ -0,0 +1,52 @@
package parallel_limiter
import (
"sync"
)
type ParallelLimiter struct {
maxCount int
mu sync.Mutex
currentRunning int
waiting []chan struct{}
}
func New(count int) *ParallelLimiter {
return &ParallelLimiter{
maxCount: count,
currentRunning: 0,
waiting: []chan struct{}{},
}
}
func (p *ParallelLimiter) Start() chan struct{} {
p.mu.Lock()
if p.currentRunning+1 > p.maxCount {
waitChan := make(chan struct{})
p.waiting = append(p.waiting, waitChan)
p.mu.Unlock()
return waitChan
}
p.currentRunning++
p.mu.Unlock()
instantResolveChan := make(chan struct{})
go func() {
instantResolveChan <- struct{}{}
}()
return instantResolveChan
}
func (p *ParallelLimiter) Finished() {
p.mu.Lock()
if len(p.waiting) > 0 {
first := p.waiting[0]
p.waiting = p.waiting[1:]
first <- struct{}{}
p.currentRunning++
}
p.currentRunning--
p.mu.Unlock()
}

19
signal/signal.go Normal file
View File

@@ -0,0 +1,19 @@
package signal
type Signal struct {
sigChan chan struct{}
}
func New() *Signal {
return &Signal{sigChan: make(chan struct{}, 1)}
}
func (s *Signal) Signal() {
if len(s.sigChan) == 0 {
s.sigChan <- struct{}{}
}
}
func (s *Signal) Chan() chan struct{} {
return s.sigChan
}

45
taskgroup/taskgroup.go Normal file
View File

@@ -0,0 +1,45 @@
package taskgroup
import (
"sync"
)
type TaskGroup struct {
count int
mu sync.Mutex
done []chan struct{}
}
func New() *TaskGroup {
return &TaskGroup{
count: 0,
mu: sync.Mutex{},
done: []chan struct{}{},
}
}
func (t *TaskGroup) Add() {
t.mu.Lock()
defer t.mu.Unlock()
t.count++
}
func (t *TaskGroup) Done() {
t.mu.Lock()
defer t.mu.Unlock()
if t.count-1 == 0 {
for _, doneChannel := range t.done {
doneChannel <- struct{}{}
}
t.done = []chan struct{}{}
}
t.count--
}
func (t *TaskGroup) Wait() {
doneChannel := make(chan struct{})
t.mu.Lock()
t.done = append(t.done, doneChannel)
t.mu.Unlock()
<-doneChannel
}

24
util/util.go Normal file
View File

@@ -0,0 +1,24 @@
package util
import (
"log"
"net"
)
// get the local ip and port based on our destination ip
func LocalIPPort(dstip net.IP) (net.IP, int) {
serverAddr, err := net.ResolveUDPAddr("udp", dstip.String()+":12345")
if err != nil {
log.Fatal(err)
}
// We don't actually connect to anything, but we can determine
// based on our destination ip what source ip we should use.
if con, err := net.DialUDP("udp", nil, serverAddr); err == nil {
defer con.Close()
if udpaddr, ok := con.LocalAddr().(*net.UDPAddr); ok {
return udpaddr.IP, udpaddr.Port
}
}
return nil, -1
}