diff --git a/cmd/lotus-gateway/main.go b/cmd/lotus-gateway/main.go index 87dea40d2..e315af404 100644 --- a/cmd/lotus-gateway/main.go +++ b/cmd/lotus-gateway/main.go @@ -138,11 +138,21 @@ var runCmd = &cli.Command{ Usage: "rate-limit API calls. Use 0 to disable", Value: 0, }, + &cli.Int64Flag{ + Name: "per-conn-rate-limit", + Usage: "rate-limit API calls per each connection. Use 0 to disable", + Value: 0, + }, &cli.DurationFlag{ Name: "rate-limit-timeout", Usage: "the maximum time to wait for the rate limter before returning an error to clients", Value: gateway.DefaultRateLimitTimeout, }, + &cli.Int64Flag{ + Name: "conn-per-minute", + Usage: "The number of incomming connections to accept from a single IP per minute. Use 0 to disable", + Value: 0, + }, }, Action: func(cctx *cli.Context) error { log.Info("Starting lotus gateway") @@ -165,7 +175,9 @@ var runCmd = &cli.Command{ address = cctx.String("listen") waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit")) rateLimit = cctx.Int64("rate-limit") + perConnRateLimit = cctx.Int64("per-conn-rate-limit") rateLimitTimeout = cctx.Duration("rate-limit-timeout") + connPerMinute = cctx.Int64("conn-per-minute") ) serverOptions := make([]jsonrpc.ServerOption, 0) @@ -186,7 +198,7 @@ var runCmd = &cli.Command{ } gwapi := gateway.NewNode(api, lookbackCap, waitLookback, rateLimit, rateLimitTimeout) - h, err := gateway.Handler(gwapi, serverOptions...) + h, err := gateway.Handler(gwapi, perConnRateLimit, connPerMinute, serverOptions...) if err != nil { return xerrors.Errorf("failed to set up gateway HTTP handler") } diff --git a/gateway/handler.go b/gateway/handler.go index f8da5a5e1..e2bc0a9a2 100644 --- a/gateway/handler.go +++ b/gateway/handler.go @@ -1,7 +1,11 @@ package gateway import ( + "context" + "net" "net/http" + "sync" + "time" "contrib.go.opencensus.io/exporter/prometheus" "github.com/filecoin-project/go-jsonrpc" @@ -11,10 +15,11 @@ import ( "github.com/filecoin-project/lotus/metrics/proxy" "github.com/gorilla/mux" promclient "github.com/prometheus/client_golang/prometheus" + "golang.org/x/time/rate" ) // Handler returns a gateway http.Handler, to be mounted as-is on the server. -func Handler(a api.Gateway, opts ...jsonrpc.ServerOption) (http.Handler, error) { +func Handler(a api.Gateway, rateLimit int64, connPerMinute int64, opts ...jsonrpc.ServerOption) (http.Handler, error) { m := mux.NewRouter() serveRpc := func(path string, hnd interface{}) { @@ -44,5 +49,95 @@ func Handler(a api.Gateway, opts ...jsonrpc.ServerOption) (http.Handler, error) Next: mux.ServeHTTP, }*/ - return m, nil + rlh := NewRateLimiterHandler(m, rateLimit) + clh := NewConnectionRateLimiterHandler(rlh, connPerMinute) + return clh, nil +} + +func NewRateLimiterHandler(handler http.Handler, rateLimit int64) *RateLimiterHandler { + limiter := limiterFromRateLimit(rateLimit) + + return &RateLimiterHandler{ + handler: handler, + limiter: limiter, + } +} + +// Adds a rate limiter to the request context for per-connection rate limiting +type RateLimiterHandler struct { + handler http.Handler + limiter *rate.Limiter +} + +func (h RateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r2 := r.WithContext(context.WithValue(r.Context(), "limiter", h.limiter)) + h.handler.ServeHTTP(w, r2) +} + +// this blocks new connections if there have already been too many. +func NewConnectionRateLimiterHandler(handler http.Handler, connPerMinute int64) *ConnectionRateLimiterHandler { + ipmap := make(map[string]int64) + return &ConnectionRateLimiterHandler{ + ipmap: ipmap, + connPerMinute: connPerMinute, + handler: handler, + } +} + +type ConnectionRateLimiterHandler struct { + mu sync.Mutex + ipmap map[string]int64 + connPerMinute int64 + handler http.Handler +} + +func (h *ConnectionRateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h.connPerMinute == 0 { + h.handler.ServeHTTP(w, r) + return + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + h.mu.Lock() + seen, ok := h.ipmap[host] + if !ok { + h.ipmap[host] = 1 + h.mu.Unlock() + h.handler.ServeHTTP(w, r) + return + } + // rate limited + if seen > h.connPerMinute { + h.mu.Unlock() + w.WriteHeader(http.StatusTooManyRequests) + return + } + h.ipmap[host] = seen + 1 + h.mu.Unlock() + go func() { + select { + case <-time.After(time.Minute): + h.mu.Lock() + defer h.mu.Unlock() + h.ipmap[host] = h.ipmap[host] - 1 + if h.ipmap[host] <= 0 { + delete(h.ipmap, host) + } + } + }() + h.handler.ServeHTTP(w, r) +} + +func limiterFromRateLimit(rateLimit int64) *rate.Limiter { + var limit rate.Limit + if rateLimit == 0 { + limit = rate.Inf + } else { + limit = rate.Every(time.Second / time.Duration(rateLimit)) + } + return rate.NewLimiter(limit, stateRateLimitTokens) } diff --git a/gateway/node.go b/gateway/node.go index 80c6219e1..2e9d8b750 100644 --- a/gateway/node.go +++ b/gateway/node.go @@ -165,6 +165,13 @@ func (gw *Node) checkTimestamp(at time.Time) error { func (gw *Node) limit(ctx context.Context, tokens int) error { ctx2, cancel := context.WithTimeout(ctx, gw.rateLimitTimeout) defer cancel() + if perConnLimiter, ok := ctx2.Value("limiter").(*rate.Limiter); ok { + err := perConnLimiter.WaitN(ctx2, tokens) + if err != nil { + return fmt.Errorf("connection limited. %w", err) + } + } + err := gw.rateLimiter.WaitN(ctx2, tokens) if err != nil { stats.Record(ctx, metrics.RateLimitCount.M(1)) @@ -212,7 +219,7 @@ func (gw *Node) ChainHead(ctx context.Context) (*types.TipSet, error) { if err := gw.limit(ctx, chainRateLimitTokens); err != nil { return nil, err } - // TODO: cache and invalidate cache when timestamp is up (or have internal ChainNotify) + return gw.target.ChainHead(ctx) }