connection rate limiting
This commit is contained in:
parent
c9d3652357
commit
b30548376b
@ -138,11 +138,21 @@ var runCmd = &cli.Command{
|
|||||||
Usage: "rate-limit API calls. Use 0 to disable",
|
Usage: "rate-limit API calls. Use 0 to disable",
|
||||||
Value: 0,
|
Value: 0,
|
||||||
},
|
},
|
||||||
|
&cli.Int64Flag{
|
||||||
|
Name: "per-conn-rate-limit",
|
||||||
|
Usage: "rate-limit API calls per each connection. Use 0 to disable",
|
||||||
|
Value: 0,
|
||||||
|
},
|
||||||
&cli.DurationFlag{
|
&cli.DurationFlag{
|
||||||
Name: "rate-limit-timeout",
|
Name: "rate-limit-timeout",
|
||||||
Usage: "the maximum time to wait for the rate limter before returning an error to clients",
|
Usage: "the maximum time to wait for the rate limter before returning an error to clients",
|
||||||
Value: gateway.DefaultRateLimitTimeout,
|
Value: gateway.DefaultRateLimitTimeout,
|
||||||
},
|
},
|
||||||
|
&cli.Int64Flag{
|
||||||
|
Name: "conn-per-minute",
|
||||||
|
Usage: "The number of incomming connections to accept from a single IP per minute. Use 0 to disable",
|
||||||
|
Value: 0,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Action: func(cctx *cli.Context) error {
|
Action: func(cctx *cli.Context) error {
|
||||||
log.Info("Starting lotus gateway")
|
log.Info("Starting lotus gateway")
|
||||||
@ -165,7 +175,9 @@ var runCmd = &cli.Command{
|
|||||||
address = cctx.String("listen")
|
address = cctx.String("listen")
|
||||||
waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit"))
|
waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit"))
|
||||||
rateLimit = cctx.Int64("rate-limit")
|
rateLimit = cctx.Int64("rate-limit")
|
||||||
|
perConnRateLimit = cctx.Int64("per-conn-rate-limit")
|
||||||
rateLimitTimeout = cctx.Duration("rate-limit-timeout")
|
rateLimitTimeout = cctx.Duration("rate-limit-timeout")
|
||||||
|
connPerMinute = cctx.Int64("conn-per-minute")
|
||||||
)
|
)
|
||||||
|
|
||||||
serverOptions := make([]jsonrpc.ServerOption, 0)
|
serverOptions := make([]jsonrpc.ServerOption, 0)
|
||||||
@ -186,7 +198,7 @@ var runCmd = &cli.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
gwapi := gateway.NewNode(api, lookbackCap, waitLookback, rateLimit, rateLimitTimeout)
|
gwapi := gateway.NewNode(api, lookbackCap, waitLookback, rateLimit, rateLimitTimeout)
|
||||||
h, err := gateway.Handler(gwapi, serverOptions...)
|
h, err := gateway.Handler(gwapi, perConnRateLimit, connPerMinute, serverOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return xerrors.Errorf("failed to set up gateway HTTP handler")
|
return xerrors.Errorf("failed to set up gateway HTTP handler")
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"contrib.go.opencensus.io/exporter/prometheus"
|
"contrib.go.opencensus.io/exporter/prometheus"
|
||||||
"github.com/filecoin-project/go-jsonrpc"
|
"github.com/filecoin-project/go-jsonrpc"
|
||||||
@ -11,10 +15,11 @@ import (
|
|||||||
"github.com/filecoin-project/lotus/metrics/proxy"
|
"github.com/filecoin-project/lotus/metrics/proxy"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
promclient "github.com/prometheus/client_golang/prometheus"
|
promclient "github.com/prometheus/client_golang/prometheus"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler returns a gateway http.Handler, to be mounted as-is on the server.
|
// Handler returns a gateway http.Handler, to be mounted as-is on the server.
|
||||||
func Handler(a api.Gateway, opts ...jsonrpc.ServerOption) (http.Handler, error) {
|
func Handler(a api.Gateway, rateLimit int64, connPerMinute int64, opts ...jsonrpc.ServerOption) (http.Handler, error) {
|
||||||
m := mux.NewRouter()
|
m := mux.NewRouter()
|
||||||
|
|
||||||
serveRpc := func(path string, hnd interface{}) {
|
serveRpc := func(path string, hnd interface{}) {
|
||||||
@ -44,5 +49,95 @@ func Handler(a api.Gateway, opts ...jsonrpc.ServerOption) (http.Handler, error)
|
|||||||
Next: mux.ServeHTTP,
|
Next: mux.ServeHTTP,
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
return m, nil
|
rlh := NewRateLimiterHandler(m, rateLimit)
|
||||||
|
clh := NewConnectionRateLimiterHandler(rlh, connPerMinute)
|
||||||
|
return clh, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRateLimiterHandler(handler http.Handler, rateLimit int64) *RateLimiterHandler {
|
||||||
|
limiter := limiterFromRateLimit(rateLimit)
|
||||||
|
|
||||||
|
return &RateLimiterHandler{
|
||||||
|
handler: handler,
|
||||||
|
limiter: limiter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds a rate limiter to the request context for per-connection rate limiting
|
||||||
|
type RateLimiterHandler struct {
|
||||||
|
handler http.Handler
|
||||||
|
limiter *rate.Limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h RateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r2 := r.WithContext(context.WithValue(r.Context(), "limiter", h.limiter))
|
||||||
|
h.handler.ServeHTTP(w, r2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// this blocks new connections if there have already been too many.
|
||||||
|
func NewConnectionRateLimiterHandler(handler http.Handler, connPerMinute int64) *ConnectionRateLimiterHandler {
|
||||||
|
ipmap := make(map[string]int64)
|
||||||
|
return &ConnectionRateLimiterHandler{
|
||||||
|
ipmap: ipmap,
|
||||||
|
connPerMinute: connPerMinute,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectionRateLimiterHandler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
ipmap map[string]int64
|
||||||
|
connPerMinute int64
|
||||||
|
handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ConnectionRateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if h.connPerMinute == 0 {
|
||||||
|
h.handler.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
seen, ok := h.ipmap[host]
|
||||||
|
if !ok {
|
||||||
|
h.ipmap[host] = 1
|
||||||
|
h.mu.Unlock()
|
||||||
|
h.handler.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// rate limited
|
||||||
|
if seen > h.connPerMinute {
|
||||||
|
h.mu.Unlock()
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.ipmap[host] = seen + 1
|
||||||
|
h.mu.Unlock()
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Minute):
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.ipmap[host] = h.ipmap[host] - 1
|
||||||
|
if h.ipmap[host] <= 0 {
|
||||||
|
delete(h.ipmap, host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
h.handler.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func limiterFromRateLimit(rateLimit int64) *rate.Limiter {
|
||||||
|
var limit rate.Limit
|
||||||
|
if rateLimit == 0 {
|
||||||
|
limit = rate.Inf
|
||||||
|
} else {
|
||||||
|
limit = rate.Every(time.Second / time.Duration(rateLimit))
|
||||||
|
}
|
||||||
|
return rate.NewLimiter(limit, stateRateLimitTokens)
|
||||||
}
|
}
|
||||||
|
@ -165,6 +165,13 @@ func (gw *Node) checkTimestamp(at time.Time) error {
|
|||||||
func (gw *Node) limit(ctx context.Context, tokens int) error {
|
func (gw *Node) limit(ctx context.Context, tokens int) error {
|
||||||
ctx2, cancel := context.WithTimeout(ctx, gw.rateLimitTimeout)
|
ctx2, cancel := context.WithTimeout(ctx, gw.rateLimitTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
if perConnLimiter, ok := ctx2.Value("limiter").(*rate.Limiter); ok {
|
||||||
|
err := perConnLimiter.WaitN(ctx2, tokens)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connection limited. %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err := gw.rateLimiter.WaitN(ctx2, tokens)
|
err := gw.rateLimiter.WaitN(ctx2, tokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
stats.Record(ctx, metrics.RateLimitCount.M(1))
|
stats.Record(ctx, metrics.RateLimitCount.M(1))
|
||||||
@ -212,7 +219,7 @@ func (gw *Node) ChainHead(ctx context.Context) (*types.TipSet, error) {
|
|||||||
if err := gw.limit(ctx, chainRateLimitTokens); err != nil {
|
if err := gw.limit(ctx, chainRateLimitTokens); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// TODO: cache and invalidate cache when timestamp is up (or have internal ChainNotify)
|
|
||||||
return gw.target.ChainHead(ctx)
|
return gw.target.ChainHead(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user