improve: plugin hook for TTLComplete

This commit is contained in:
sjlleo
2023-09-07 22:20:55 +08:00
parent bb522ed859
commit c532dfd05c
6 changed files with 51 additions and 26 deletions

View File

@@ -39,8 +39,9 @@ func (t *ICMPTracer) SetConfig(c Config) {
}
func (t *ICMPTracer) Execute() (*Result, error) {
t.foundIPs = make(map[string]bool)
t.inflightRequestRWLock.Lock()
t.foundIPs = make(map[string]bool)
t.inflightRequest = make(map[int]chan Hop)
t.inflightRequestRWLock.Unlock()
@@ -64,8 +65,6 @@ func (t *ICMPTracer) Execute() (*Result, error) {
go t.listenICMP()
for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ {
for _, plugin := range t.Plugins {
// if use Hook
//plgn.ExecuteHook(plugin, "OnTTLChange", ttl)
plugin.OnTTLChange(ttl)
}
t.inflightRequestRWLock.Lock()
@@ -79,7 +78,6 @@ func (t *ICMPTracer) Execute() (*Result, error) {
go t.send(ttl)
<-time.After(t.Config.PacketInterval)
}
<-time.After(t.Config.TTLInterval)
}
@@ -257,10 +255,12 @@ func (t *ICMPTracer) send(ttl int) error {
rtt := time.Since(start)
ipStr := h.Address.String()
if !t.foundIPs[ipStr] {
t.inflightRequestRWLock.Lock()
t.foundIPs[ipStr] = true
t.inflightRequestRWLock.Unlock()
// Trigger 所有插件的 OnIPFound 方法
for _, plugin := range t.Plugins {
plugin.OnIPFound(h.Address)
plugin.OnNewIPFound(h.Address)
}
}
if t.final != -1 && ttl > t.final {
@@ -299,5 +299,8 @@ func (t *ICMPTracer) send(ttl int) error {
}
for _, plugin := range t.Plugins {
plugin.OnTTLCompleted(len(t.res.Hops), t.res.Hops[len(t.res.Hops)-1])
}
return nil
}

View File

@@ -26,6 +26,7 @@ type ICMPTracerv6 struct {
icmpListen net.PacketConn
final int
finalLock sync.Mutex
foundIPs map[string]bool
}
func (t *ICMPTracerv6) GetConfig() *Config {
@@ -61,7 +62,11 @@ func (t *ICMPTracerv6) Execute() (*Result, error) {
go t.listenICMP()
for ttl := t.BeginHop; ttl <= t.MaxHops; ttl++ {
for _, plugin := range t.Plugins {
plugin.OnTTLChange(ttl)
}
t.inflightRequestRWLock.Lock()
t.foundIPs = make(map[string]bool)
t.inflightRequest[ttl] = make(chan Hop, t.NumMeasurements)
t.inflightRequestRWLock.Unlock()
if t.final != -1 && ttl > t.final {

View File

@@ -7,8 +7,6 @@ import (
"net"
"sync"
"time"
"github.com/sjlleo/nexttrace-core/plgn"
)
var (
@@ -24,6 +22,13 @@ type TraceInstance struct {
ErrorStr string
}
type Plugin interface {
OnDNSResolve(domain string) (net.IP, error)
OnNewIPFound(ip net.Addr) error
OnTTLChange(ttl int) error
OnTTLCompleted(ttl int, hop []Hop) error
}
type Config struct {
TraceMethod Method
SrcAddr string
@@ -37,7 +42,7 @@ type Config struct {
Quic bool
PacketInterval time.Duration
TTLInterval time.Duration
Plugins []plgn.Plugin
Plugins []Plugin
}
const (
@@ -65,12 +70,12 @@ type Hop struct {
Error error
}
func Traceroute(p []plgn.Plugin) {
func Traceroute(p []Plugin) {
var test_config = Config{
DestIP: net.IPv4(1, 1, 1, 1),
DestPort: 443,
ParallelRequests: 30,
NumMeasurements: 1,
NumMeasurements: 3,
BeginHop: 1,
MaxHops: 30,
TTLInterval: 1 * time.Millisecond,

View File

@@ -3,6 +3,8 @@ package plgn
import (
"fmt"
"net"
"github.com/sjlleo/nexttrace-core/core"
)
type DebugPlugin struct {
@@ -10,7 +12,7 @@ type DebugPlugin struct {
DebugLevel int
}
func NewDebugPlugin(params interface{}) Plugin {
func NewDebugPlugin(params interface{}) core.Plugin {
debugLevel, ok := params.(int)
if !ok {
return nil
@@ -25,9 +27,16 @@ func (d *DebugPlugin) OnTTLChange(ttl int) error {
return nil
}
func (d *DebugPlugin) OnIPFound(ip net.Addr) error {
func (d *DebugPlugin) OnNewIPFound(ip net.Addr) error {
if d.DebugLevel <= 2 {
fmt.Println("Debug Level 2: New IP Found: ", ip)
}
return nil
}
func (d *DebugPlugin) OnTTLCompleted(ttl int, hop []core.Hop) error {
if d.DebugLevel <= 2 {
fmt.Println("Debug Level 2: ttl=", ttl, "Hop:", hop)
}
return nil
}

View File

@@ -1,6 +1,10 @@
package plgn
import "net"
import (
"net"
"github.com/sjlleo/nexttrace-core/core"
)
type DefaultPlugin struct {
}
@@ -9,10 +13,14 @@ func (d *DefaultPlugin) OnDNSResolve(domain string) (net.IP, error) {
return nil, nil
}
func (d *DefaultPlugin) OnIPFound(ip net.Addr) error {
func (d *DefaultPlugin) OnNewIPFound(ip net.Addr) error {
return nil
}
func (d *DefaultPlugin) OnTTLChange(ttl int) error {
return nil
}
func (d *DefaultPlugin) OnTTLCompleted(ttl int, hop []core.Hop) error {
return nil
}

View File

@@ -2,25 +2,20 @@ package plgn
import (
"log"
"net"
"reflect"
"strings"
"github.com/sjlleo/nexttrace-core/core"
)
type Plugin interface {
OnDNSResolve(domain string) (net.IP, error)
OnIPFound(ip net.Addr) error
OnTTLChange(ttl int) error
}
var pluginRegistry = make(map[string]func(interface{}) core.Plugin)
var pluginRegistry = make(map[string]func(interface{}) Plugin)
func RegisterPlugin(name string, constructor func(interface{}) Plugin) {
func RegisterPlugin(name string, constructor func(interface{}) core.Plugin) {
pluginRegistry[name] = constructor
}
func CreatePlugins(enabledPlugins string, params interface{}) []Plugin {
var plugins []Plugin
func CreatePlugins(enabledPlugins string, params interface{}) []core.Plugin {
var plugins []core.Plugin
for _, name := range strings.Split(enabledPlugins, ",") {
if constructor, exists := pluginRegistry[name]; exists {
plugins = append(plugins, constructor(params))
@@ -29,7 +24,7 @@ func CreatePlugins(enabledPlugins string, params interface{}) []Plugin {
return plugins
}
func ExecuteHook(plugin Plugin, hookName string, args ...interface{}) {
func ExecuteHook(plugin core.Plugin, hookName string, args ...interface{}) {
v := reflect.ValueOf(plugin)
method := v.MethodByName(hookName)