diff --git a/trace/udp.go b/trace/udp.go index c9bc423..21edcb8 100644 --- a/trace/udp.go +++ b/trace/udp.go @@ -4,6 +4,7 @@ import ( "log" "math/rand" "net" + "strconv" "sync" "time" @@ -135,9 +136,11 @@ func (t *UDPTracer) handleICMPMessage(msg ReceivedMessage, data []byte) { } } -func (t *UDPTracer) getUDPConn(try int) (net.IP, int, net.PacketConn) { - srcIP, _ := util.LocalIPPort(t.DestIP) +var cachedLocalPort int +var localPortOnce sync.Once +func (t *UDPTracer) getUDPConn(try int) (net.IP, int, net.PacketConn, error) { + srcIP, _ := util.LocalIPPort(t.DestIP) var ipString string if srcIP == nil { ipString = "" @@ -145,14 +148,37 @@ func (t *UDPTracer) getUDPConn(try int) (net.IP, int, net.PacketConn) { ipString = srcIP.String() } - udpConn, err := net.ListenPacket("udp", ipString+":0") - if err != nil { - if try > 3 { - log.Fatal(err) + // Check environment variable to decide caching behavior + if util.GetenvDefault("NEXTTRACE_RANDOMPORT", "") == "" { + // Use cached random port logic + if cachedLocalPort == 0 { + // First time: listen on a random port + udpConn, err := net.ListenPacket("udp", ipString+":0") + if err != nil { + if try > 3 { + log.Fatal(err) + } + return srcIP, 0, nil, err + } + cachedLocalPort = udpConn.LocalAddr().(*net.UDPAddr).Port + // Close the initial connection after obtaining the port + udpConn.Close() } - return t.getUDPConn(try + 1) + // Use the cached local port to establish a new connection + udpConn, err := net.ListenPacket("udp", ipString+":"+strconv.Itoa(cachedLocalPort)) + if err != nil { + return srcIP, cachedLocalPort, nil, err + } + return srcIP, cachedLocalPort, udpConn, nil + } else { + // Without caching: create a new connection each time using a new random port + udpConn, err := net.ListenPacket("udp", ipString+":0") + if err != nil { + return srcIP, 0, nil, err + } + localPort := udpConn.LocalAddr().(*net.UDPAddr).Port + return srcIP, localPort, udpConn, nil } - return srcIP, udpConn.LocalAddr().(*net.UDPAddr).Port, udpConn } func (t *UDPTracer) send(ttl int) error { @@ -167,7 +193,10 @@ func (t *UDPTracer) send(ttl int) error { return nil } - srcIP, srcPort, udpConn := t.getUDPConn(0) + srcIP, srcPort, udpConn, err := t.getUDPConn(0) + if err != nil { + return err + } defer udpConn.Close() //var payload []byte