From 1460ad67c047609bcdb33e018ed25dd4eb3dca81 Mon Sep 17 00:00:00 2001 From: Leo Date: Sun, 11 Jun 2023 14:49:16 +0800 Subject: [PATCH] refactor: nexttrace core --- example/traceroute_test.go | 39 +++++++++++ trace/icmp_ipv4.go | 23 +++---- trace/icmp_ipv6.go | 29 +++----- trace/tcp_ipv4.go | 22 +++---- trace/tcp_ipv6.go | 13 +++- trace/trace.go | 132 +++++++++++++++++++++++++------------ trace/udp_ipv4.go | 12 +++- trace/udp_ipv6.go | 12 +++- 8 files changed, 192 insertions(+), 90 deletions(-) create mode 100644 example/traceroute_test.go diff --git a/example/traceroute_test.go b/example/traceroute_test.go new file mode 100644 index 0000000..566ab9c --- /dev/null +++ b/example/traceroute_test.go @@ -0,0 +1,39 @@ +package example + +import ( + "log" + "net" + "testing" + "time" + + "github.com/sjlleo/nexttrace-core/trace" +) + +func traceroute() { + var test_config = trace.Config{ + DestIP: net.IPv4(1, 1, 1, 1), + DestPort: 443, + ParallelRequests: 30, + NumMeasurements: 1, + BeginHop: 1, + MaxHops: 30, + TTLInterval: 1 * time.Millisecond, + Timeout: 2 * time.Second, + TraceMethod: trace.ICMPTrace, + } + traceInstance, err := trace.NewTracer(test_config) + if err != nil { + log.Println(err) + return + } + + res, err := traceInstance.Traceroute() + if err != nil { + log.Println(err) + } + log.Println(res) +} + +func TestTraceToCloudflareDNS(t *testing.T) { + traceroute() +} diff --git a/trace/icmp_ipv4.go b/trace/icmp_ipv4.go index 6bb5f56..82795d8 100644 --- a/trace/icmp_ipv4.go +++ b/trace/icmp_ipv4.go @@ -28,6 +28,14 @@ type ICMPTracer struct { finalLock sync.Mutex } +func (t *ICMPTracer) GetConfig() *Config { + return &t.Config +} + +func (t *ICMPTracer) SetConfig(c Config) { + t.Config = c +} + func (t *ICMPTracer) Execute() (*Result, error) { t.inflightRequestRWLock.Lock() t.inflightRequest = make(map[int]chan Hop) @@ -51,7 +59,6 @@ func (t *ICMPTracer) Execute() (*Result, error) { t.final = -1 go t.listenICMP() - t.wg.Add(1) for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ { t.inflightRequestRWLock.Lock() t.inflightRequest[ttl] = make(chan Hop, t.NumMeasurements) @@ -62,9 +69,9 @@ func (t *ICMPTracer) Execute() (*Result, error) { for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) go t.send(ttl) - <-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval)) + <-time.After(t.Config.PacketInterval) } - <-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval)) + <-time.After(t.Config.TTLInterval) } t.wg.Wait() @@ -190,17 +197,11 @@ func reverseID(id string) (int64, int64, error) { } if parity%2 == 1 { - if id[len(id)-1] == '0' { - // fmt.Println("Parity check passed.") - } else { - // fmt.Println("Parity check failed.") + if id[len(id)-1] != '0' { return 0, 0, errors.New("err") } } else { - if id[len(id)-1] == '1' { - // fmt.Println("Parity check passed.") - } else { - // fmt.Println("Parity check failed.") + if id[len(id)-1] != '1' { return 0, 0, errors.New("err") } } diff --git a/trace/icmp_ipv6.go b/trace/icmp_ipv6.go index 81041ed..8c67bd4 100644 --- a/trace/icmp_ipv6.go +++ b/trace/icmp_ipv6.go @@ -27,6 +27,14 @@ type ICMPTracerv6 struct { finalLock sync.Mutex } +func (t *ICMPTracerv6) GetConfig() *Config { + return &t.Config +} + +func (t *ICMPTracerv6) SetConfig(c Config) { + t.Config = c +} + func (t *ICMPTracerv6) Execute() (*Result, error) { t.inflightRequestRWLock.Lock() t.inflightRequest = make(map[int]chan Hop) @@ -61,28 +69,11 @@ func (t *ICMPTracerv6) Execute() (*Result, error) { for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) go t.send(ttl) - <-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval)) + <-time.After(t.Config.PacketInterval) } - <-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval)) + <-time.After(t.Config.TTLInterval) } - // for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ { - // if t.final != -1 && ttl > t.final { - // break - // } - // for i := 0; i < t.NumMeasurements; i++ { - // t.wg.Add(1) - // go t.send(ttl) - // } - // // 一组TTL全部退出(收到应答或者超时终止)以后,再进行下一个TTL的包发送 - // t.wg.Wait() - // if t.RealtimePrinter != nil { - // t.RealtimePrinter(&t.res, ttl-1) - // } - // if t.AsyncPrinter != nil { - // t.AsyncPrinter(&t.res) - // } - // } t.wg.Wait() t.res.reduce(t.final) if t.final == -1 { diff --git a/trace/tcp_ipv4.go b/trace/tcp_ipv4.go index 0c15779..6ca74ea 100644 --- a/trace/tcp_ipv4.go +++ b/trace/tcp_ipv4.go @@ -34,6 +34,14 @@ type TCPTracer struct { sem *semaphore.Weighted } +func (t *TCPTracer) GetConfig() *Config { + return &t.Config +} + +func (t *TCPTracer) SetConfig(c Config) { + t.Config = c +} + func (t *TCPTracer) Execute() (*Result, error) { if len(t.res.Hops) > 0 { return &t.res, ErrTracerouteExecuted @@ -79,9 +87,9 @@ func (t *TCPTracer) Execute() (*Result, error) { for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) go t.send(ttl) - + <-time.After(t.Config.PacketInterval) } - time.Sleep(1 * time.Millisecond) + <-time.After(t.Config.TTLInterval) } t.res.reduce(t.final) @@ -229,15 +237,7 @@ func (t *TCPTracer) send(ttl int) error { hopCh := make(chan Hop) t.inflightRequest[int(sequenceNumber)] = hopCh t.inflightRequestLock.Unlock() - /* - // 这里属于 2个Sender,N个Reciever的情况,在哪里关闭Channel都容易导致Panic - defer func() { - t.inflightRequestLock.Lock() - close(hopCh) - delete(t.inflightRequest, srcPort) - t.inflightRequestLock.Unlock() - }() - */ + select { case <-t.ctx.Done(): return nil diff --git a/trace/tcp_ipv6.go b/trace/tcp_ipv6.go index 1f1e44d..1cb10c7 100644 --- a/trace/tcp_ipv6.go +++ b/trace/tcp_ipv6.go @@ -34,13 +34,20 @@ type TCPTracerv6 struct { sem *semaphore.Weighted } +func (t *TCPTracerv6) GetConfig() *Config { + return &t.Config +} + +func (t *TCPTracerv6) SetConfig(c Config) { + t.Config = c +} + func (t *TCPTracerv6) Execute() (*Result, error) { if len(t.res.Hops) > 0 { return &t.res, ErrTracerouteExecuted } t.SrcIP, _ = util.LocalIPPortv6(t.DestIP) - // log.Println(util.LocalIPPortv6(t.DestIP)) var err error t.tcp, err = net.ListenPacket("ip6:tcp", t.SrcIP.String()) if err != nil { @@ -71,9 +78,9 @@ func (t *TCPTracerv6) Execute() (*Result, error) { for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) go t.send(ttl) + <-time.After(t.Config.PacketInterval) } - time.Sleep(1 * time.Millisecond) - + <-time.After(t.Config.TTLInterval) } t.res.reduce(t.final) diff --git a/trace/trace.go b/trace/trace.go index 6438609..36c71ed 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -2,6 +2,8 @@ package trace import ( "errors" + "fmt" + "log" "net" "sync" "time" @@ -13,7 +15,15 @@ var ( ErrHopLimitTimeout = errors.New("hop timeout") ) +type Method string + +type TraceInstance struct { + Tracer + ErrorStr string +} + type Config struct { + TraceMethod Method SrcAddr string BeginHop int MaxHops int @@ -23,12 +33,10 @@ type Config struct { DestIP net.IP DestPort int Quic bool - PacketInterval int - TTLInterval int + PacketInterval time.Duration + TTLInterval time.Duration } -type Method string - const ( ICMPTrace Method = "icmp" UDPTrace Method = "udp" @@ -37,36 +45,8 @@ const ( type Tracer interface { Execute() (*Result, error) -} - -func Traceroute(method Method, config Config) (*Result, error) { - var tracer Tracer - - switch method { - case ICMPTrace: - if config.DestIP.To4() != nil { - tracer = &ICMPTracer{Config: config} - } else { - tracer = &ICMPTracerv6{Config: config} - } - - case UDPTrace: - if config.DestIP.To4() != nil { - tracer = &UDPTracer{Config: config} - } else { - return nil, errors.New("IPv6 UDP Traceroute is not supported") - } - case TCPTrace: - if config.DestIP.To4() != nil { - tracer = &TCPTracer{Config: config} - } else { - tracer = &TCPTracerv6{Config: config} - // return nil, errors.New("IPv6 TCP Traceroute is not supported") - } - default: - return &Result{}, ErrInvalidMethod - } - return tracer.Execute() + GetConfig() *Config + SetConfig(Config) } type Result struct { @@ -74,6 +54,82 @@ type Result struct { lock sync.Mutex } +type Hop struct { + Address net.Addr + Hostname string + TTL int + RTT time.Duration + Error error +} + +func NewTracer(config Config) (*TraceInstance, error) { + t := TraceInstance{} + switch config.TraceMethod { + case ICMPTrace: + if config.DestIP.To4() != nil { + t.Tracer = &ICMPTracer{Config: config} + } else { + t.Tracer = &ICMPTracerv6{Config: config} + } + + case UDPTrace: + if config.DestIP.To4() != nil { + t.Tracer = &UDPTracer{Config: config} + } else { + t.Tracer = &UDPTracerv6{Config: config} + } + case TCPTrace: + if config.DestIP.To4() != nil { + t.Tracer = &TCPTracer{Config: config} + } else { + t.Tracer = &TCPTracerv6{Config: config} + } + default: + return &TraceInstance{}, ErrInvalidMethod + } + return &t, t.CheckConfig() +} + +func (t *TraceInstance) CheckConfig() (err error) { + c := t.GetConfig() + + configValidConditions := map[string]bool{ + "DestIP is null": c.DestIP == nil, + "BeginHop is empty": c.BeginHop == 0, + "MaxHops is empty": c.MaxHops == 0, + "NumMeasurements is empty": c.NumMeasurements == 0, + "ParallelRequests is empty": c.ParallelRequests == 0, + "Trace Timeout is empty": c.Timeout == 0, + "You must specific at least one of TTLInterval and PacketInterval": c.TTLInterval|c.PacketInterval == 0, + "You choose " + string(c.TraceMethod) + " trace. DestPort must be specified": (c.TraceMethod == TCPTrace || c.TraceMethod == UDPTrace) && c.DestPort == 0, + } + + var ( + inValidFlag bool + ) + + for condition, notValid := range configValidConditions { + if notValid { + inValidFlag = true + t.ErrorStr += fmt.Sprintf("Invalid config: %s\n", condition) + } + } + + if inValidFlag { + return fmt.Errorf(t.ErrorStr) + } + + return nil + +} + +func (t *TraceInstance) Traceroute() (*Result, error) { + if t.ErrorStr != "" { + log.Fatal(t.ErrorStr) + } + return t.Tracer.Execute() +} + func (s *Result) add(hop Hop) { s.lock.Lock() defer s.lock.Unlock() @@ -90,11 +146,3 @@ func (s *Result) reduce(final int) { s.Hops = s.Hops[:final] } } - -type Hop struct { - Address net.Addr - Hostname string - TTL int - RTT time.Duration - Error error -} diff --git a/trace/udp_ipv4.go b/trace/udp_ipv4.go index 0ab508c..ae5062a 100644 --- a/trace/udp_ipv4.go +++ b/trace/udp_ipv4.go @@ -31,6 +31,14 @@ type UDPTracer struct { sem *semaphore.Weighted } +func (t *UDPTracer) GetConfig() *Config { + return &t.Config +} + +func (t *UDPTracer) SetConfig(c Config) { + t.Config = c +} + func (t *UDPTracer) Execute() (*Result, error) { if len(t.res.Hops) > 0 { return &t.res, ErrTracerouteExecuted @@ -60,9 +68,9 @@ func (t *UDPTracer) Execute() (*Result, error) { for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) go t.send(ttl) - + <-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval)) } - time.Sleep(1 * time.Millisecond) + <-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval)) } t.res.reduce(t.final) diff --git a/trace/udp_ipv6.go b/trace/udp_ipv6.go index 7c70fbc..525624e 100644 --- a/trace/udp_ipv6.go +++ b/trace/udp_ipv6.go @@ -31,6 +31,14 @@ type UDPTracerv6 struct { sem *semaphore.Weighted } +func (t *UDPTracerv6) GetConfig() *Config { + return &t.Config +} + +func (t *UDPTracerv6) SetConfig(c Config) { + t.Config = c +} + func (t *UDPTracerv6) Execute() (*Result, error) { if len(t.res.Hops) > 0 { return &t.res, ErrTracerouteExecuted @@ -60,9 +68,9 @@ func (t *UDPTracerv6) Execute() (*Result, error) { for i := 0; i < t.NumMeasurements; i++ { t.wg.Add(1) go t.send(ttl) - + <-time.After(time.Millisecond * time.Duration(t.Config.PacketInterval)) } - time.Sleep(1 * time.Millisecond) + <-time.After(time.Millisecond * time.Duration(t.Config.TTLInterval)) } t.res.reduce(t.final)