275 lines
6.8 KiB
Go
275 lines
6.8 KiB
Go
|
package quic
|
||
|
|
||
|
import (
|
||
|
"crypto/hmac"
|
||
|
"crypto/rand"
|
||
|
"crypto/sha256"
|
||
|
"errors"
|
||
|
"hash"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
||
|
"github.com/lucas-clemente/quic-go/internal/wire"
|
||
|
)
|
||
|
|
||
|
// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store sessions
// * when multiplexing outgoing connections to store clients
//
// It owns a single net.PacketConn; newPacketHandlerMap starts one listen
// goroutine that reads from it and dispatches each packet by connection ID.
type packetHandlerMap struct {
	// mutex guards handlers, resetTokens, server and closed.
	mutex sync.RWMutex

	// conn is the socket that listen reads from and that stateless resets are written to.
	conn net.PacketConn
	// connIDLen is passed to wire.ParseConnectionID for every received packet.
	connIDLen int

	handlers    map[string] /* string(ConnectionID)*/ packetHandler
	resetTokens map[[16]byte] /* stateless reset token */ packetHandler
	// server receives long header packets whose connection ID has no handler; nil when unset.
	server unknownPacketHandler

	listening chan struct{} // is closed when listen returns
	// closed is set by the first call to close; later calls are no-ops.
	closed bool

	// deleteRetiredSessionsAfter is the delay between retiring a connection ID and removing its handler.
	deleteRetiredSessionsAfter time.Duration

	// statelessResetEnabled is true when a non-empty stateless reset key was supplied.
	statelessResetEnabled bool
	// statelessResetHasher is an HMAC-SHA256 keyed with the stateless reset key.
	// It is stateful (Write / Sum / Reset), so that sequence must not run concurrently.
	statelessResetHasher hash.Hash

	logger utils.Logger
}
|
||
|
|
||
|
// Compile-time check that *packetHandlerMap implements packetHandlerManager.
var _ packetHandlerManager = (*packetHandlerMap)(nil)
|
||
|
|
||
|
func newPacketHandlerMap(
|
||
|
conn net.PacketConn,
|
||
|
connIDLen int,
|
||
|
statelessResetKey []byte,
|
||
|
logger utils.Logger,
|
||
|
) packetHandlerManager {
|
||
|
m := &packetHandlerMap{
|
||
|
conn: conn,
|
||
|
connIDLen: connIDLen,
|
||
|
listening: make(chan struct{}),
|
||
|
handlers: make(map[string]packetHandler),
|
||
|
resetTokens: make(map[[16]byte]packetHandler),
|
||
|
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
|
||
|
statelessResetEnabled: len(statelessResetKey) > 0,
|
||
|
statelessResetHasher: hmac.New(sha256.New, statelessResetKey),
|
||
|
logger: logger,
|
||
|
}
|
||
|
go m.listen()
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
|
||
|
h.mutex.Lock()
|
||
|
h.handlers[string(id)] = handler
|
||
|
h.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||
|
h.removeByConnectionIDAsString(string(id))
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
|
||
|
h.mutex.Lock()
|
||
|
delete(h.handlers, id)
|
||
|
h.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
|
||
|
h.retireByConnectionIDAsString(string(id))
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) retireByConnectionIDAsString(id string) {
|
||
|
time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
|
||
|
h.removeByConnectionIDAsString(id)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) {
|
||
|
h.mutex.Lock()
|
||
|
h.resetTokens[token] = handler
|
||
|
h.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) RemoveResetToken(token [16]byte) {
|
||
|
h.mutex.Lock()
|
||
|
delete(h.resetTokens, token)
|
||
|
h.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
|
||
|
h.mutex.Lock()
|
||
|
h.server = s
|
||
|
h.mutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) CloseServer() {
|
||
|
h.mutex.Lock()
|
||
|
h.server = nil
|
||
|
var wg sync.WaitGroup
|
||
|
for id, handler := range h.handlers {
|
||
|
if handler.getPerspective() == protocol.PerspectiveServer {
|
||
|
wg.Add(1)
|
||
|
go func(id string, handler packetHandler) {
|
||
|
// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
|
||
|
_ = handler.Close()
|
||
|
h.retireByConnectionIDAsString(id)
|
||
|
wg.Done()
|
||
|
}(id, handler)
|
||
|
}
|
||
|
}
|
||
|
h.mutex.Unlock()
|
||
|
wg.Wait()
|
||
|
}
|
||
|
|
||
|
// Close the underlying connection and wait until listen() has returned.
|
||
|
func (h *packetHandlerMap) Close() error {
|
||
|
if err := h.conn.Close(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
<-h.listening // wait until listening returns
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) close(e error) error {
|
||
|
h.mutex.Lock()
|
||
|
if h.closed {
|
||
|
h.mutex.Unlock()
|
||
|
return nil
|
||
|
}
|
||
|
h.closed = true
|
||
|
|
||
|
var wg sync.WaitGroup
|
||
|
for _, handler := range h.handlers {
|
||
|
wg.Add(1)
|
||
|
go func(handler packetHandler) {
|
||
|
handler.destroy(e)
|
||
|
wg.Done()
|
||
|
}(handler)
|
||
|
}
|
||
|
|
||
|
if h.server != nil {
|
||
|
h.server.closeWithError(e)
|
||
|
}
|
||
|
h.mutex.Unlock()
|
||
|
wg.Wait()
|
||
|
return getMultiplexer().RemoveConn(h.conn)
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) listen() {
|
||
|
defer close(h.listening)
|
||
|
for {
|
||
|
buffer := getPacketBuffer()
|
||
|
data := buffer.Slice
|
||
|
// The packet size should not exceed protocol.MaxReceivePacketSize bytes
|
||
|
// If it does, we only read a truncated packet, which will then end up undecryptable
|
||
|
n, addr, err := h.conn.ReadFrom(data)
|
||
|
if err != nil {
|
||
|
h.close(err)
|
||
|
return
|
||
|
}
|
||
|
h.handlePacket(addr, buffer, data[:n])
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) handlePacket(
|
||
|
addr net.Addr,
|
||
|
buffer *packetBuffer,
|
||
|
data []byte,
|
||
|
) {
|
||
|
connID, err := wire.ParseConnectionID(data, h.connIDLen)
|
||
|
if err != nil {
|
||
|
h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err)
|
||
|
return
|
||
|
}
|
||
|
rcvTime := time.Now()
|
||
|
|
||
|
h.mutex.RLock()
|
||
|
defer h.mutex.RUnlock()
|
||
|
|
||
|
if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
handler, handlerFound := h.handlers[string(connID)]
|
||
|
|
||
|
p := &receivedPacket{
|
||
|
remoteAddr: addr,
|
||
|
rcvTime: rcvTime,
|
||
|
buffer: buffer,
|
||
|
data: data,
|
||
|
}
|
||
|
if handlerFound { // existing session
|
||
|
handler.handlePacket(p)
|
||
|
return
|
||
|
}
|
||
|
if data[0]&0x80 == 0 {
|
||
|
go h.maybeSendStatelessReset(p, connID)
|
||
|
return
|
||
|
}
|
||
|
if h.server == nil { // no server set
|
||
|
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
|
||
|
return
|
||
|
}
|
||
|
h.server.handlePacket(p)
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
|
||
|
// stateless resets are always short header packets
|
||
|
if data[0]&0x80 != 0 {
|
||
|
return false
|
||
|
}
|
||
|
if len(data) < protocol.MinStatelessResetSize {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
var token [16]byte
|
||
|
copy(token[:], data[len(data)-16:])
|
||
|
if sess, ok := h.resetTokens[token]; ok {
|
||
|
h.logger.Debugf("Received a stateless retry with token %#x. Closing session.", token)
|
||
|
go sess.destroy(errors.New("received a stateless reset"))
|
||
|
return true
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte {
|
||
|
var token [16]byte
|
||
|
if !h.statelessResetEnabled {
|
||
|
// Return a random stateless reset token.
|
||
|
// This token will be sent in the server's transport parameters.
|
||
|
// By using a random token, an off-path attacker won't be able to disrupt the connection.
|
||
|
rand.Read(token[:])
|
||
|
return token
|
||
|
}
|
||
|
h.statelessResetHasher.Write(connID.Bytes())
|
||
|
copy(token[:], h.statelessResetHasher.Sum(nil))
|
||
|
h.statelessResetHasher.Reset()
|
||
|
return token
|
||
|
}
|
||
|
|
||
|
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
|
||
|
defer p.buffer.Release()
|
||
|
if !h.statelessResetEnabled {
|
||
|
return
|
||
|
}
|
||
|
// Don't send a stateless reset in response to very small packets.
|
||
|
// This includes packets that could be stateless resets.
|
||
|
if len(p.data) <= protocol.MinStatelessResetSize {
|
||
|
return
|
||
|
}
|
||
|
token := h.GetStatelessResetToken(connID)
|
||
|
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
|
||
|
data := make([]byte, 23)
|
||
|
rand.Read(data)
|
||
|
data[0] = (data[0] & 0x7f) | 0x40
|
||
|
data = append(data, token[:]...)
|
||
|
if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil {
|
||
|
h.logger.Debugf("Error sending Stateless Reset: %s", err)
|
||
|
}
|
||
|
}
|