p2p/nat: limit UPNP request concurrency (#21390)

This adds a lock around requests because some routers can't handle
concurrent requests. Requests are also rate-limited.
 
The Map function request a new mapping exactly when the map timeout
occurs instead of 5 minutes earlier. This should prevent duplicate mappings.
This commit is contained in:
Felix Lange 2020-08-05 09:51:37 +02:00 committed by GitHub
parent 82a9e11058
commit 1d25039ff5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 25 deletions

View File

@ -91,15 +91,14 @@ func Parse(spec string) (Interface, error) {
} }
const ( const (
mapTimeout = 20 * time.Minute mapTimeout = 10 * time.Minute
mapUpdateInterval = 15 * time.Minute
) )
// Map adds a port mapping on m and keeps it alive until c is closed. // Map adds a port mapping on m and keeps it alive until c is closed.
// This function is typically invoked in its own goroutine. // This function is typically invoked in its own goroutine.
func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) { func Map(m Interface, c <-chan struct{}, protocol string, extport, intport int, name string) {
log := log.New("proto", protocol, "extport", extport, "intport", intport, "interface", m) log := log.New("proto", protocol, "extport", extport, "intport", intport, "interface", m)
refresh := time.NewTimer(mapUpdateInterval) refresh := time.NewTimer(mapTimeout)
defer func() { defer func() {
refresh.Stop() refresh.Stop()
log.Debug("Deleting port mapping") log.Debug("Deleting port mapping")
@ -121,7 +120,7 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na
if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil { if err := m.AddMapping(protocol, extport, intport, name, mapTimeout); err != nil {
log.Debug("Couldn't add port mapping", "err", err) log.Debug("Couldn't add port mapping", "err", err)
} }
refresh.Reset(mapUpdateInterval) refresh.Reset(mapTimeout)
} }
} }
} }

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/huin/goupnp" "github.com/huin/goupnp"
@ -28,12 +29,17 @@ import (
"github.com/huin/goupnp/dcps/internetgateway2" "github.com/huin/goupnp/dcps/internetgateway2"
) )
const soapRequestTimeout = 3 * time.Second const (
soapRequestTimeout = 3 * time.Second
rateLimit = 200 * time.Millisecond
)
type upnp struct { type upnp struct {
dev *goupnp.RootDevice dev *goupnp.RootDevice
service string service string
client upnpClient client upnpClient
mu sync.Mutex
lastReqTime time.Time
} }
type upnpClient interface { type upnpClient interface {
@ -43,8 +49,23 @@ type upnpClient interface {
GetNATRSIPStatus() (sip bool, nat bool, err error) GetNATRSIPStatus() (sip bool, nat bool, err error)
} }
func (n *upnp) natEnabled() bool {
var ok bool
var err error
n.withRateLimit(func() error {
_, ok, err = n.client.GetNATRSIPStatus()
return err
})
return err == nil && ok
}
func (n *upnp) ExternalIP() (addr net.IP, err error) { func (n *upnp) ExternalIP() (addr net.IP, err error) {
ipString, err := n.client.GetExternalIPAddress() var ipString string
n.withRateLimit(func() error {
ipString, err = n.client.GetExternalIPAddress()
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,7 +84,10 @@ func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, li
protocol = strings.ToUpper(protocol) protocol = strings.ToUpper(protocol)
lifetimeS := uint32(lifetime / time.Second) lifetimeS := uint32(lifetime / time.Second)
n.DeleteMapping(protocol, extport, intport) n.DeleteMapping(protocol, extport, intport)
return n.withRateLimit(func() error {
return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS) return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
})
} }
func (n *upnp) internalAddress() (net.IP, error) { func (n *upnp) internalAddress() (net.IP, error) {
@ -90,36 +114,51 @@ func (n *upnp) internalAddress() (net.IP, error) {
} }
func (n *upnp) DeleteMapping(protocol string, extport, intport int) error { func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
return n.withRateLimit(func() error {
return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol)) return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
})
} }
func (n *upnp) String() string { func (n *upnp) String() string {
return "UPNP " + n.service return "UPNP " + n.service
} }
func (n *upnp) withRateLimit(fn func() error) error {
n.mu.Lock()
defer n.mu.Unlock()
lastreq := time.Since(n.lastReqTime)
if lastreq < rateLimit {
time.Sleep(rateLimit - lastreq)
}
err := fn()
n.lastReqTime = time.Now()
return err
}
// discoverUPnP searches for Internet Gateway Devices // discoverUPnP searches for Internet Gateway Devices
// and returns the first one it can find on the local network. // and returns the first one it can find on the local network.
func discoverUPnP() Interface { func discoverUPnP() Interface {
found := make(chan *upnp, 2) found := make(chan *upnp, 2)
// IGDv1 // IGDv1
go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(sc goupnp.ServiceClient) *upnp {
switch sc.Service.ServiceType { switch sc.Service.ServiceType {
case internetgateway1.URN_WANIPConnection_1: case internetgateway1.URN_WANIPConnection_1:
return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{ServiceClient: sc}} return &upnp{service: "IGDv1-IP1", client: &internetgateway1.WANIPConnection1{ServiceClient: sc}}
case internetgateway1.URN_WANPPPConnection_1: case internetgateway1.URN_WANPPPConnection_1:
return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{ServiceClient: sc}} return &upnp{service: "IGDv1-PPP1", client: &internetgateway1.WANPPPConnection1{ServiceClient: sc}}
} }
return nil return nil
}) })
// IGDv2 // IGDv2
go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp { go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(sc goupnp.ServiceClient) *upnp {
switch sc.Service.ServiceType { switch sc.Service.ServiceType {
case internetgateway2.URN_WANIPConnection_1: case internetgateway2.URN_WANIPConnection_1:
return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{ServiceClient: sc}} return &upnp{service: "IGDv2-IP1", client: &internetgateway2.WANIPConnection1{ServiceClient: sc}}
case internetgateway2.URN_WANIPConnection_2: case internetgateway2.URN_WANIPConnection_2:
return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{ServiceClient: sc}} return &upnp{service: "IGDv2-IP2", client: &internetgateway2.WANIPConnection2{ServiceClient: sc}}
case internetgateway2.URN_WANPPPConnection_1: case internetgateway2.URN_WANPPPConnection_1:
return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{ServiceClient: sc}} return &upnp{service: "IGDv2-PPP1", client: &internetgateway2.WANPPPConnection1{ServiceClient: sc}}
} }
return nil return nil
}) })
@ -134,7 +173,7 @@ func discoverUPnP() Interface {
// finds devices matching the given target and calls matcher for all // finds devices matching the given target and calls matcher for all
// advertised services of each device. The first non-nil service found // advertised services of each device. The first non-nil service found
// is sent into out. If no service matched, nil is sent. // is sent into out. If no service matched, nil is sent.
func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) { func discover(out chan<- *upnp, target string, matcher func(goupnp.ServiceClient) *upnp) {
devs, err := goupnp.DiscoverDevices(target) devs, err := goupnp.DiscoverDevices(target)
if err != nil { if err != nil {
out <- nil out <- nil
@ -157,16 +196,17 @@ func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice,
Service: service, Service: service,
} }
sc.SOAPClient.HTTPClient.Timeout = soapRequestTimeout sc.SOAPClient.HTTPClient.Timeout = soapRequestTimeout
upnp := matcher(devs[i].Root, sc) upnp := matcher(sc)
if upnp == nil { if upnp == nil {
return return
} }
upnp.dev = devs[i].Root
// check whether port mapping is enabled // check whether port mapping is enabled
if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat { if upnp.natEnabled() {
return
}
out <- upnp out <- upnp
found = true found = true
}
}) })
} }
if !found { if !found {