Files
NTrace-core/methods/tcp/tcp.go
2022-04-24 10:07:01 +08:00

329 lines
7.4 KiB
Go

package tcp
import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"traceroute/listener_channel"
"traceroute/methods"
"traceroute/parallel_limiter"
"traceroute/signal"
"traceroute/util"
"golang.org/x/net/context"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"log"
"math"
"math/rand"
"net"
"sync"
"time"
)
// State types for a single TCP SYN traceroute run.
type (
	// inflightData describes one probe that has been sent and is waiting to
	// be completed by a listener or expired by the timeout loop.
	inflightData struct {
		// start is when the probe was sent; used for the RTT and for expiry.
		start time.Time
		// ttl is the IP TTL the probe was sent with.
		ttl uint16
	}

	// results gathers everything produced while a traceroute runs.
	results struct {
		// inflightRequests maps a probe's TCP sequence number to its inflightData.
		inflightRequests sync.Map
		// results holds the hops observed per TTL; only touched under resultsMu.
		results   map[uint16][]methods.TracerouteHop
		resultsMu sync.Mutex
		// err is the error hit while sending probes, reported by start().
		err error
		// concurrentRequests bounds how many probes may be in flight at once.
		concurrentRequests *parallel_limiter.ParallelLimiter
		// reachedFinalHop is signalled once a reply comes from the destination.
		reachedFinalHop *signal.Signal
	}

	// Traceroute runs a TCP SYN based traceroute towards a single destination.
	Traceroute struct {
		opConfig    opConfig
		trcrtConfig methods.TracerouteConfig
		results     results
	}

	// opConfig holds the sockets and lifecycle state of one Start call.
	opConfig struct {
		icmpConn net.PacketConn
		tcpConn  net.PacketConn
		// tcpMu keeps each probe's SetTTL+WriteTo pair on tcpConn together.
		tcpMu  sync.Mutex
		destIP net.IP
		srcIP  net.IP
		// wg counts the send loop plus every probe that has not completed yet.
		wg     *sync.WaitGroup
		ctx    context.Context
		cancel context.CancelFunc
	}
)
// New prepares a traceroute towards destIP using config. Nothing is opened
// or sent until Start is called.
func New(destIP net.IP, config methods.TracerouteConfig) *Traceroute {
	tr := new(Traceroute)
	tr.opConfig.destIP = destIP
	tr.trcrtConfig = config
	return tr
}
// Start opens the raw TCP and ICMP sockets, runs the traceroute to completion
// and returns the per-TTL hops reduced by methods.ReduceFinalResult.
//
// Both sockets are closed before Start returns, on success as well as on every
// error path. Previously they were never closed, and a failure to open the
// ICMP socket leaked the already-open TCP socket.
func (tr *Traceroute) Start() (*map[uint16][]methods.TracerouteHop, error) {
	tr.opConfig.srcIP, _ = util.LocalIPPort(tr.opConfig.destIP)

	tcpConn, err := net.ListenPacket("ip4:tcp", tr.opConfig.srcIP.String())
	if err != nil {
		return nil, err
	}
	// Closing also unblocks any ReadFrom still pending in the listener readers.
	// NOTE(review): confirm listener_channel tolerates its conn being closed.
	defer tcpConn.Close()

	icmpConn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
	if err != nil {
		return nil, err
	}
	defer icmpConn.Close()

	tr.opConfig.tcpConn = tcpConn
	tr.opConfig.icmpConn = icmpConn

	// Created only after both sockets are open, so no context is left
	// uncancelled on the error returns above. start() cancels it.
	tr.opConfig.ctx, tr.opConfig.cancel = context.WithCancel(context.Background())

	var wg sync.WaitGroup
	tr.opConfig.wg = &wg
	tr.results = results{
		inflightRequests:   sync.Map{},
		concurrentRequests: parallel_limiter.New(int(tr.trcrtConfig.ParallelRequests)),
		reachedFinalHop:    signal.New(),
		results:            map[uint16][]methods.TracerouteHop{},
	}
	return tr.start()
}
// timeoutLoop expires probes that stay in flight longer than the configured
// timeout. Each expired probe is recorded as an unsuccessful hop for its TTL,
// and its concurrency slot and WaitGroup count are released. The in-flight set
// is scanned every Timeout/4.
//
// When the traceroute context is cancelled, every probe still in flight is
// released the same way and the function returns. Neither this goroutine nor
// its ticker outlives the run, and nothing is left blocking wg.Wait.
func (tr *Traceroute) timeoutLoop() {
	// expire completes every in-flight probe for which due reports true.
	expire := func(due func(inflightData) bool) {
		tr.results.inflightRequests.Range(func(key, value interface{}) bool {
			if !due(value.(inflightData)) {
				return true
			}
			// A listener may claim the same entry concurrently. Only the
			// goroutine whose LoadAndDelete succeeds may complete the probe;
			// otherwise wg.Done would run twice for it.
			claimed, ok := tr.results.inflightRequests.LoadAndDelete(key)
			if !ok {
				return true
			}
			request := claimed.(inflightData)
			tr.addToResult(request.ttl, methods.TracerouteHop{
				Success: false,
				TTL:     request.ttl,
			})
			tr.results.concurrentRequests.Finished()
			tr.opConfig.wg.Done()
			return true
		})
	}

	interval := tr.trcrtConfig.Timeout / 4
	if interval <= 0 {
		// time.NewTicker panics on a non-positive interval.
		interval = time.Millisecond
	}
	ticker := time.NewTicker(interval)
	// Stop does not close ticker.C. Selecting on ctx.Done in the same loop
	// is what lets this goroutine exit (a `for range ticker.C` would leak).
	defer ticker.Stop()

	for {
		select {
		case <-tr.opConfig.ctx.Done():
			expire(func(inflightData) bool { return true })
			return
		case <-ticker.C:
			expire(func(request inflightData) bool {
				return time.Since(request.start) > tr.trcrtConfig.Timeout
			})
		}
	}
}
// addToResult appends hop to the hops recorded for ttl. It is called from the
// listener and timeout goroutines, so the map is only accessed under resultsMu.
func (tr *Traceroute) addToResult(ttl uint16, hop methods.TracerouteHop) {
	r := &tr.results
	r.resultsMu.Lock()
	defer r.resultsMu.Unlock()
	// append on a missing (nil) entry allocates the slice itself.
	r.results[ttl] = append(r.results[ttl], hop)
}
// handleICMPMessage matches an ICMP error (time exceeded / destination
// unreachable) to the probe it quotes. data is the ICMP body. The TCP sequence
// number is extracted from the quoted header and used to claim the in-flight
// probe. Messages that quote no known probe are ignored.
//
// On a match, the sender is recorded as a successful hop with its RTT. The
// probe's concurrency slot and WaitGroup count are released. If the sender is
// the destination itself, reachedFinalHop is signalled.
func (tr *Traceroute) handleICMPMessage(msg listener_channel.ReceivedMessage, data []byte) {
	quoted, err := methods.GetICMPResponsePayload(data)
	if err != nil {
		return
	}
	entry, found := tr.results.inflightRequests.LoadAndDelete(methods.GetTCPSeq(quoted))
	if !found {
		return
	}
	probe := entry.(inflightData)
	rtt := time.Since(probe.start)

	if fromDestination := msg.Peer.String() == tr.opConfig.destIP.String(); fromDestination {
		tr.results.reachedFinalHop.Signal()
	}
	tr.addToResult(probe.ttl, methods.TracerouteHop{
		Success: true,
		Address: msg.Peer,
		TTL:     probe.ttl,
		RTT:     &rtt,
	})
	tr.results.concurrentRequests.Finished()
	tr.opConfig.wg.Done()
}
// icmpListener reads ICMP messages from icmpConn until the traceroute context
// is cancelled. Time-exceeded and destination-unreachable messages are handed
// to handleICMPMessage. Unparseable messages and other ICMP types are logged
// and skipped.
func (tr *Traceroute) icmpListener() {
	listener := listener_channel.New(tr.opConfig.icmpConn)
	defer listener.Stop()
	go listener.Start()

	for {
		var msg listener_channel.ReceivedMessage
		select {
		case <-tr.opConfig.ctx.Done():
			return
		case msg = <-listener.Messages:
		}
		if msg.N == nil {
			continue
		}

		// 1 is the IP protocol number of ICMPv4.
		parsed, err := icmp.ParseMessage(1, msg.Msg[:*msg.N])
		if err != nil {
			log.Println(err)
			continue
		}

		var payload []byte
		switch parsed.Type {
		case ipv4.ICMPTypeTimeExceeded:
			payload = parsed.Body.(*icmp.TimeExceeded).Data
		case ipv4.ICMPTypeDestinationUnreachable:
			payload = parsed.Body.(*icmp.DstUnreach).Data
		default:
			log.Println("received icmp message of unknown type")
			continue
		}
		tr.handleICMPMessage(msg, payload)
	}
}
// tcpListener reads TCP segments from tcpConn until the traceroute context is
// cancelled. Segments from the destination whose Ack-1 equals an in-flight
// probe's sequence number complete that probe. They are recorded as a
// successful hop with their RTT, and reachedFinalHop is signalled.
func (tr *Traceroute) tcpListener() {
	listener := listener_channel.New(tr.opConfig.tcpConn)
	defer listener.Stop()
	go listener.Start()

	for {
		var msg listener_channel.ReceivedMessage
		select {
		case <-tr.opConfig.ctx.Done():
			return
		case msg = <-listener.Messages:
		}
		// Only segments sent by the destination can answer a probe.
		if msg.N == nil || msg.Peer.String() != tr.opConfig.destIP.String() {
			continue
		}

		decoded := gopacket.NewPacket(msg.Msg[:*msg.N], layers.LayerTypeTCP, gopacket.Default)
		layer := decoded.Layer(layers.LayerTypeTCP)
		if layer == nil {
			continue
		}
		segment, _ := layer.(*layers.TCP)

		// A reply acknowledges the probe's SYN, i.e. Ack == Seq+1.
		entry, found := tr.results.inflightRequests.LoadAndDelete(segment.Ack - 1)
		if !found {
			continue
		}
		probe := entry.(inflightData)
		tr.results.concurrentRequests.Finished()
		rtt := time.Since(probe.start)
		// The sender was already checked to be the destination above.
		tr.results.reachedFinalHop.Signal()
		tr.addToResult(probe.ttl, methods.TracerouteHop{
			Success: true,
			Address: msg.Peer,
			TTL:     probe.ttl,
			RTT:     &rtt,
		})
		tr.opConfig.wg.Done()
	}
}
// sendMessage sends one TCP SYN probe to the destination with the given IP TTL.
// It registers the probe in inflightRequests, keyed by its sequence number, so
// that a listener or the timeout loop can complete it.
//
// sendLoop has already done wg.Add(1) and acquired a concurrentRequests slot
// for this probe. Whenever the probe cannot be handed over (serialization or
// socket errors, or a cancelled run), this function releases both itself.
// Previously the error paths released neither, and start() blocked forever
// in wg.Wait.
func (tr *Traceroute) sendMessage(ttl uint16) {
	// release gives back what sendLoop reserved for this probe.
	release := func() {
		tr.results.concurrentRequests.Finished()
		tr.opConfig.wg.Done()
	}
	// fail records err for start() and cancels the run. Several probe
	// goroutines may fail at once, so the write is guarded and the first error
	// wins. Call it before release so the write happens before wg.Wait returns.
	fail := func(err error) {
		tr.results.resultsMu.Lock()
		if tr.results.err == nil {
			tr.results.err = err
		}
		tr.results.resultsMu.Unlock()
		tr.opConfig.cancel()
	}

	_, srcPort := util.LocalIPPort(tr.opConfig.destIP)
	// ipHeader is never serialized (only tcpHeader is written below; the TTL
	// is applied with SetTTL). It only supplies the checksum pseudo-header.
	ipHeader := &layers.IPv4{
		SrcIP:    tr.opConfig.srcIP,
		DstIP:    tr.opConfig.destIP,
		Protocol: layers.IPProtocolTCP,
		TTL:      uint8(ttl),
	}
	// Int63n covers every uint32 value and, unlike rand.Intn(math.MaxUint32),
	// does not overflow int on 32-bit platforms.
	sequenceNumber := uint32(rand.Int63n(math.MaxUint32 + 1))
	tcpHeader := &layers.TCP{
		SrcPort: layers.TCPPort(srcPort),
		DstPort: layers.TCPPort(tr.trcrtConfig.Port),
		Seq:     sequenceNumber,
		SYN:     true,
		Window:  14600,
	}
	// Only rejects non-IPv4/IPv6 layers, which cannot happen here.
	_ = tcpHeader.SetNetworkLayerForChecksum(ipHeader)

	buf := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}
	if err := gopacket.SerializeLayers(buf, opts, tcpHeader); err != nil {
		fail(err)
		release()
		return
	}

	// The TTL is a property of the shared socket, so SetTTL and WriteTo must
	// not interleave with another probe's pair.
	tr.opConfig.tcpMu.Lock()
	defer tr.opConfig.tcpMu.Unlock()
	if err := ipv4.NewPacketConn(tr.opConfig.tcpConn).SetTTL(int(ttl)); err != nil {
		fail(err)
		release()
		return
	}

	// Register before writing. Otherwise a fast reply could be read by a
	// listener before the entry exists and be dropped as unknown.
	tr.results.inflightRequests.Store(sequenceNumber, inflightData{start: time.Now(), ttl: ttl})
	// claim removes this probe's entry. Only a successful claim may release
	// the probe: a listener or the timeout loop may already have completed it.
	claim := func() bool {
		_, ok := tr.results.inflightRequests.LoadAndDelete(sequenceNumber)
		return ok
	}

	// If the run was already cancelled (e.g. another probe failed), do not
	// send. Take the entry back so it does not rely on a timeout sweep.
	if tr.opConfig.ctx.Err() != nil {
		if claim() {
			release()
		}
		return
	}
	if _, err := tr.opConfig.tcpConn.WriteTo(buf.Bytes(), &net.IPAddr{IP: tr.opConfig.destIP}); err != nil {
		if claim() {
			fail(err)
			release()
		}
	}
}
// sendLoop launches NumMeasurements probes for each TTL from 1 to MaxHops.
// MaxHops is capped at 255, the largest value the 8-bit IPv4 TTL field holds.
// Each probe runs in its own goroutine and is gated by the concurrency limiter.
// The loop stops early once the destination has answered or the traceroute
// context is cancelled.
//
// The caller made the wg.Add(1) for this goroutine. Every launched probe adds
// its own count, which is released when that probe completes.
func (tr *Traceroute) sendLoop() {
	rand.Seed(time.Now().UTC().UnixNano())
	defer tr.opConfig.wg.Done()

	maxTTL := tr.trcrtConfig.MaxHops
	// Larger values would be truncated by uint8(ttl) in sendMessage. Also,
	// MaxHops == math.MaxUint16 would make the uint16 counter below wrap
	// around and the loop never terminate.
	if maxTTL > math.MaxUint8 {
		maxTTL = math.MaxUint8
	}

	for ttl := uint16(1); ttl <= maxTTL; ttl++ {
		select {
		case <-tr.results.reachedFinalHop.Chan():
			return
		default:
		}
		for i := 0; i < int(tr.trcrtConfig.NumMeasurements); i++ {
			select {
			case <-tr.opConfig.ctx.Done():
				return
			case <-tr.results.concurrentRequests.Start():
				tr.opConfig.wg.Add(1)
				go tr.sendMessage(ttl)
			}
		}
	}
}
// start launches the timeout loop, both listeners and the send loop. It waits
// until the send loop and every probe it launched have completed, then cancels
// the context so the background goroutines exit. If an error was recorded
// while sending, it is returned. Otherwise start returns the hops reduced by
// methods.ReduceFinalResult.
func (tr *Traceroute) start() (*map[uint16][]methods.TracerouteHop, error) {
	for _, background := range []func(){tr.timeoutLoop, tr.icmpListener, tr.tcpListener} {
		go background()
	}

	tr.opConfig.wg.Add(1) // matched by the deferred Done in sendLoop
	go tr.sendLoop()
	tr.opConfig.wg.Wait()
	tr.opConfig.cancel()

	if sendErr := tr.results.err; sendErr != nil {
		return nil, sendErr
	}
	reduced := methods.ReduceFinalResult(tr.results.results, tr.trcrtConfig.MaxHops, tr.opConfig.destIP)
	return &reduced, nil
}