From b20c4b74cc6c1b86b1ff08c7b057f80b4adda1d5 Mon Sep 17 00:00:00 2001 From: Yunlq <140733746+Yunlq@users.noreply.github.com> Date: Wed, 16 Apr 2025 02:54:19 +0800 Subject: [PATCH] add support for custom source ports and optimize some code --- cmd/cmd.go | 8 +++++--- trace/tcp_ipv4.go | 7 ++++++- trace/tcp_ipv6.go | 7 ++++++- trace/trace.go | 1 + trace/udp_ipv4.go | 7 +++++-- trace/udp_ipv6.go | 7 +++++-- util/util.go | 24 ++++++++++++------------ 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 8d11558..7e63c82 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -57,12 +57,13 @@ func Excute() { disableMaptrace := parser.Flag("M", "map", &argparse.Options{Help: "Disable Print Trace Map"}) disableMPLS := parser.Flag("e", "disable-mpls", &argparse.Options{Help: "Disable MPLS"}) ver := parser.Flag("v", "version", &argparse.Options{Help: "Print version info and exit"}) - srcAddr := parser.String("s", "source", &argparse.Options{Help: "Use source src_addr for outgoing packets"}) + srcAddr := parser.String("s", "source", &argparse.Options{Help: "Use source address src_addr for outgoing packets"}) + srcPort := parser.Int("", "source-port", &argparse.Options{Help: "Use source port src_port for outgoing packets (TCP and UDP only)"}) srcDev := parser.String("D", "dev", &argparse.Options{Help: "Use the following Network Devices as the source address in outgoing packets"}) //router := parser.Flag("R", "route", &argparse.Options{Help: "Show Routing Table [Provided By BGP.Tools]"}) - packetInterval := parser.Int("z", "send-time", &argparse.Options{Default: 50, Help: "Set how many [milliseconds] between sending each packet.. Useful when some routers use rate-limit for ICMP messages"}) + packetInterval := parser.Int("z", "send-time", &argparse.Options{Default: 50, Help: "Set how many [milliseconds] between sending each packet. Useful when some routers use rate-limit for ICMP messages"}) ttlInterval := parser.Int("i", "ttl-time", &argparse.Options{Default: 50, Help: "Set how many [milliseconds] between sending packets groups by TTL. Useful when some routers use rate-limit for ICMP messages"}) - timeout := parser.Int("", "timeout", &argparse.Options{Default: 1000, Help: "The number of [milliseconds] to keep probe sockets open before giving up on the connection."}) + timeout := parser.Int("", "timeout", &argparse.Options{Default: 1000, Help: "The number of [milliseconds] to keep probe sockets open before giving up on the connection"}) packetSize := parser.Int("", "psize", &argparse.Options{Default: 52, Help: "Set the payload size"}) str := parser.StringPositional(&argparse.Options{Help: "IP Address or domain name"}) dot := parser.Selector("", "dot-server", []string{"dnssb", "aliyun", "dnspod", "google", "cloudflare"}, &argparse.Options{ @@ -255,6 +256,7 @@ func Excute() { var conf = trace.Config{ DN42: *dn42, SrcAddr: *srcAddr, + SrcPort: *srcPort, BeginHop: *beginHop, DestIP: ip, DestPort: *port, diff --git a/trace/tcp_ipv4.go b/trace/tcp_ipv4.go index 01e27c5..5eaa7e0 100644 --- a/trace/tcp_ipv4.go +++ b/trace/tcp_ipv4.go @@ -211,7 +211,12 @@ func (t *TCPTracer) send(ttl int) error { } // 随机种子 r := rand.New(rand.NewSource(time.Now().UnixNano())) - _, srcPort := util.LocalIPPort(t.DestIP) + _, srcPort := func() (net.IP, int) { + if util.EnvRandomPort == "" && t.SrcPort != 0 { + return nil, t.SrcPort + } + return util.LocalIPPort(t.DestIP) + }() ipHeader := &layers.IPv4{ SrcIP: t.SrcIP, DstIP: t.DestIP, diff --git a/trace/tcp_ipv6.go b/trace/tcp_ipv6.go index dc462a9..294640f 100644 --- a/trace/tcp_ipv6.go +++ b/trace/tcp_ipv6.go @@ -200,7 +200,12 @@ func (t *TCPTracerIPv6) send(ttl int) error { } // 随机种子 r := rand.New(rand.NewSource(time.Now().UnixNano())) - _, srcPort := util.LocalIPPortv6(t.DestIP) + _, srcPort := func() (net.IP, int) { + if util.EnvRandomPort == "" && t.SrcPort != 0 { + return nil, t.SrcPort + } + return util.LocalIPPortv6(t.DestIP) + }() ipHeader := &layers.IPv6{ SrcIP: t.SrcIP, DstIP: t.DestIP, diff --git a/trace/trace.go b/trace/trace.go index fc92af9..3df1d2b 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -23,6 +23,7 @@ var ( type Config struct { SrcAddr string + SrcPort int BeginHop int MaxHops int NumMeasurements int diff --git a/trace/udp_ipv4.go b/trace/udp_ipv4.go index 06cd517..b63a5a2 100644 --- a/trace/udp_ipv4.go +++ b/trace/udp_ipv4.go @@ -149,7 +149,10 @@ func (t *UDPTracer) getUDPConn(try int) (net.IP, int, net.PacketConn, error) { } // Check environment variable to decide caching behavior - if util.GetenvDefault("NEXTTRACE_RANDOMPORT", "") == "" { + if util.EnvRandomPort == "" { + if t.SrcPort != 0 { + cachedLocalPort = t.SrcPort + } // Use cached random port logic if cachedLocalPort == 0 { // First time: listen on a random port @@ -193,7 +196,7 @@ func (t *UDPTracer) send(ttl int) error { return nil } - if util.GetenvDefault("NEXTTRACE_RANDOMPORT", "") == "" { + if util.EnvRandomPort == "" { t.udpMutex.Lock() defer t.udpMutex.Unlock() } diff --git a/trace/udp_ipv6.go b/trace/udp_ipv6.go index 652e04f..15bc596 100644 --- a/trace/udp_ipv6.go +++ b/trace/udp_ipv6.go @@ -214,7 +214,10 @@ func (t *UDPTracerIPv6) getUDPConn(try int) (net.IP, int, net.PacketConn, error) } // Check environment variable to decide caching behavior - if util.GetenvDefault("NEXTTRACE_RANDOMPORT", "") == "" { + if util.EnvRandomPort == "" { + if t.SrcPort != 0 { + cachedLocalPortv6 = t.SrcPort + } // Use cached random port logic if cachedLocalPortv6 == 0 { // First time: listen on a random port @@ -258,7 +261,7 @@ func (t *UDPTracerIPv6) send(ttl int) error { return nil } - if util.GetenvDefault("NEXTTRACE_RANDOMPORT", "") == "" { + if util.EnvRandomPort == "" { t.udpMutex.Lock() defer t.udpMutex.Unlock() } diff --git a/util/util.go b/util/util.go index bb75c16..b13929f 100644 --- a/util/util.go +++ b/util/util.go @@ -16,15 +16,16 @@ import ( "github.com/fatih/color" ) -var Uninterrupted = GetenvDefault("NEXTTRACE_UNINTERRUPTED", "") -var EnvToken = GetenvDefault("NEXTTRACE_TOKEN", "") -var EnvIPInfoLocalPath = GetenvDefault("NEXTTRACE_IPINFOLOCALPATH", "") -var UserAgent = fmt.Sprintf("NextTrace %s/%s/%s", config.Version, runtime.GOOS, runtime.GOARCH) -var RdnsCache sync.Map -var PowProviderParam = "" var DisableMPLS = GetenvDefault("NEXTTRACE_DISABLEMPLS", "") var EnableHidDstIP = GetenvDefault("NEXTTRACE_ENABLEHIDDENDSTIP", "") +var EnvIPInfoLocalPath = GetenvDefault("NEXTTRACE_IPINFOLOCALPATH", "") +var EnvRandomPort = GetenvDefault("NEXTTRACE_RANDOMPORT", "") +var EnvToken = GetenvDefault("NEXTTRACE_TOKEN", "") +var Uninterrupted = GetenvDefault("NEXTTRACE_UNINTERRUPTED", "") var DestIP string +var PowProviderParam = "" +var RdnsCache sync.Map +var UserAgent = fmt.Sprintf("NextTrace %s/%s/%s", config.Version, runtime.GOOS, runtime.GOARCH) var cachedLocalIP net.IP var cachedLocalPort int var localIPOnce sync.Once @@ -84,10 +85,10 @@ func getLocalIPPortv6(dstip net.IP) (net.IP, int) { return nil, -1 } -// LocalIPPort returns the local IP and port based on our destination IP, with caching unless NEXTTRACE_RANDOMPORT is set. +// LocalIPPort returns the local IP and port based on our destination IP, with caching unless EnvRandomPort is set. func LocalIPPort(dstip net.IP) (net.IP, int) { - // If NEXTTRACE_RANDOMPORT is set, bypass caching and return a new port every time. - if GetenvDefault("NEXTTRACE_RANDOMPORT", "") != "" { + // If EnvRandomPort is set, bypass caching and return a new port every time. + if EnvRandomPort != "" { return getLocalIPPort(dstip) } @@ -102,9 +103,8 @@ func LocalIPPort(dstip net.IP) (net.IP, int) { } func LocalIPPortv6(dstip net.IP) (net.IP, int) { - // If NEXTTRACE_RANDOMPORT is set, bypass caching and return a new port every time. - // 该ENV仅对TCP Mode有效,UDP Mode暂无办法固定Port - if GetenvDefault("NEXTTRACE_RANDOMPORT", "") != "" { + // If EnvRandomPort is set, bypass caching and return a new port every time. + if EnvRandomPort != "" { return getLocalIPPortv6(dstip) }