diff --git a/cmd/lotus-bench/main.go b/cmd/lotus-bench/main.go index 12d310b65..883f27a42 100644 --- a/cmd/lotus-bench/main.go +++ b/cmd/lotus-bench/main.go @@ -106,6 +106,7 @@ func main() { sealBenchCmd, simpleCmd, importBenchCmd, + rpcCmd, }, } diff --git a/cmd/lotus-bench/rpc.go b/cmd/lotus-bench/rpc.go new file mode 100644 index 000000000..5da784c6e --- /dev/null +++ b/cmd/lotus-bench/rpc.go @@ -0,0 +1,576 @@ +package main + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "sort" + "strconv" + "strings" + "sync" + "text/tabwriter" + "time" + + "github.com/urfave/cli/v2" +) + +var rpcCmd = &cli.Command{ + Name: "rpc", + Usage: "Runs a concurrent stress test on one or more rpc methods and prints the performance metrics including latency distribution and histogram", + Description: `This benchmark is designed to stress test the rpc methods of a lotus node so that we can simulate real world usage and measure the performance of rpc methods on the node. + +This benchmark has the following features: +* Can query each method both sequentially and concurrently +* Supports rate limiting +* Can query multiple different endpoints at once (supporting different concurrency level and rate limiting for each method) +* Gives a nice reporting summary of the stress testing of each method (including latency distribution, histogram and more) +* Easy to use + +To use this benchmark you must specify the rpc methods you want to test using the --method options, the format of it is: + + --method=NAME[:CONCURRENCY][:QPS][:PARAMS] where only NAME is required. + +Here are some real examples: + lotus-bench rpc --method='eth_chainId' // run eth_chainId with default concurrency and qps + lotus-bench rpc --method='eth_chainId:3' // override concurrency to 3 + lotus-bench rpc --method='eth_chainId::100' // override to 100 qps while using default concurrency + lotus-bench rpc --method='eth_chainId:3:100' // run using 3 workers but limit to 100 qps + lotus-bench rpc --method='eth_getTransactionCount:::["0xd4c70007F3F502f212c7e6794b94C06F36173B36", "latest"]' // run using optional params while using default concurrency and qps + lotus-bench rpc --method='eth_chainId' --method='eth_getTransactionCount:10:0:["0xd4c70007F3F502f212c7e6794b94C06F36173B36", "latest"]' // run multiple methods at once + +NOTE: The last two examples will not work until we upgrade urfave dependency (tracked in https://github.com/urfave/cli/issues/1734)`, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "endpoint", + Value: "http://127.0.0.1:1234/rpc/v1", + Usage: "The rpc endpoint to benchmark", + }, + &cli.DurationFlag{ + Name: "duration", + Value: 60 * time.Second, + Usage: "Duration of benchmark in seconds", + }, + &cli.IntFlag{ + Name: "concurrency", + Value: 10, + Usage: "How many workers should be used per rpc method (can be overridden per method)", + }, + &cli.IntFlag{ + Name: "qps", + Value: 0, + Usage: "How many requests per second should be sent per rpc method (can be overridden per method), a value of 0 means no limit", + }, + &cli.StringSliceFlag{ + Name: "method", + Usage: `Method to benchmark, you can specify multiple methods by repeating this flag. You can also specify method specific options to set the concurrency and qps for each method (see usage). +`, + }, + &cli.DurationFlag{ + Name: "watch", + Value: 0 * time.Second, + Usage: "If >0 then generates reports every N seconds (only supports linux/unix)", + }, + &cli.BoolFlag{ + Name: "print-response", + Value: false, + Usage: "print the response of each request", + }, + }, + Action: func(cctx *cli.Context) error { + if len(cctx.StringSlice("method")) == 0 { + return errors.New("you must specify and least one method to benchmark") + } + + var rpcMethods []*RPCMethod + for _, str := range cctx.StringSlice("method") { + entries := strings.SplitN(str, ":", 4) + if len(entries) == 0 { + return errors.New("invalid method format") + } + + // check if concurrency was specified + concurrency := cctx.Int("concurrency") + if len(entries) > 1 { + if len(entries[1]) > 0 { + var err error + concurrency, err = strconv.Atoi(entries[1]) + if err != nil { + return fmt.Errorf("could not parse concurrency value from method %s: %v", entries[0], err) + } + } + } + + // check if qps was specified + qps := cctx.Int("qps") + if len(entries) > 2 { + if len(entries[2]) > 0 { + var err error + qps, err = strconv.Atoi(entries[2]) + if err != nil { + return fmt.Errorf("could not parse qps value from method %s: %v", entries[0], err) + } + } + } + + // check if params was specified + params := "[]" + if len(entries) > 3 { + params = entries[3] + } + + rpcMethods = append(rpcMethods, &RPCMethod{ + w: os.Stdout, + uri: cctx.String("endpoint"), + method: entries[0], + concurrency: concurrency, + qps: qps, + params: params, + printResp: cctx.Bool("print-response"), + }) + } + + // terminate early on ctrl+c + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + <-c + fmt.Println("Received interrupt, stopping...") + for _, method := range rpcMethods { + method.Stop() + } + }() + + // stop all threads after duration + go func() { + time.Sleep(cctx.Duration("duration")) + for _, e := range rpcMethods { + e.Stop() + } + }() + + // start all threads + var wg sync.WaitGroup + wg.Add(len(rpcMethods)) + + for _, e := range rpcMethods { + go func(e *RPCMethod) { + defer wg.Done() + err := e.Run() + if err != nil { + fmt.Printf("error running rpc method: %v\n", err) + } + }(e) + } + + // if watch is set then print a report every N seconds + var progressCh chan struct{} + if cctx.Duration("watch") > 0 { + progressCh = make(chan struct{}, 1) + go func(progressCh chan struct{}) { + ticker := time.NewTicker(cctx.Duration("watch")) + for { + clearAndPrintReport := func() { + // clear the screen move the curser to the top left + fmt.Print("\033[2J") + fmt.Printf("\033[%d;%dH", 1, 1) + for i, e := range rpcMethods { + e.Report() + if i < len(rpcMethods)-1 { + fmt.Println() + } + } + } + select { + case <-ticker.C: + clearAndPrintReport() + case <-progressCh: + clearAndPrintReport() + return + } + } + }(progressCh) + } + + wg.Wait() + + if progressCh != nil { + // wait for the watch go routine to return + progressCh <- struct{}{} + + // no need to print the report again + return nil + } + + // print the report for each endpoint + for i, e := range rpcMethods { + e.Report() + if i < len(rpcMethods)-1 { + fmt.Println() + } + } + + return nil + }, +} + +// RPCMethod handles the benchmarking of a single endpoint method. +type RPCMethod struct { + w io.Writer + // the endpoint uri + uri string + // the rpc method we want to benchmark + method string + // the number of concurrent requests to make to this endpoint + concurrency int + // if >0 then limit to qps is the max number of requests per second to make to this endpoint (0 = no limit) + qps int + // many endpoints require specific parameters to be passed + params string + // whether or not to print the response of each request (useful for debugging) + printResp bool + // instruct the worker go routines to stop + stopCh chan struct{} + // when the endpoint bencharking started + start time.Time + // results channel is used by the workers to send results to the reporter + results chan *result + // reporter handles reading the results from workers and printing the report statistics + reporter *Reporter +} + +// result is the result of a single rpc method request. +type result struct { + err error + statusCode *int + duration time.Duration +} + +func (rpc *RPCMethod) Run() error { + client := &http.Client{ + Timeout: 0, + } + + var wg sync.WaitGroup + wg.Add(rpc.concurrency) + + rpc.results = make(chan *result, rpc.concurrency*1_000) + rpc.stopCh = make(chan struct{}, rpc.concurrency) + + go func() { + rpc.reporter = NewReporter(rpc.results, rpc.w) + rpc.reporter.Run() + }() + + rpc.start = time.Now() + + // throttle the number of requests per second + var qpsTicker *time.Ticker + if rpc.qps > 0 { + qpsTicker = time.NewTicker(time.Second / time.Duration(rpc.qps)) + } + + for i := 0; i < rpc.concurrency; i++ { + go func() { + rpc.startWorker(client, qpsTicker) + wg.Done() + }() + } + wg.Wait() + + // close the results channel so reporter will stop + close(rpc.results) + + // wait until the reporter is done + <-rpc.reporter.doneCh + + return nil +} + +func (rpc *RPCMethod) startWorker(client *http.Client, qpsTicker *time.Ticker) { + for { + // check if we should stop + select { + case <-rpc.stopCh: + return + default: + } + + // wait for the next tick if we are rate limiting this endpoint + if qpsTicker != nil { + <-qpsTicker.C + } + + req, err := rpc.buildRequest() + if err != nil { + log.Fatalln(err) + } + + start := time.Now() + + var statusCode *int + + // send request the endpoint + resp, err := client.Do(req) + if err != nil { + err = fmt.Errorf("HTTP error: %s", err.Error()) + } else { + statusCode = &resp.StatusCode + + // there was not a HTTP error but we need to still check the json response for errrors + var data []byte + data, err = io.ReadAll(resp.Body) + if err != nil { + log.Fatalln(err) + } + + // we are only interested if it has the error field in the response + type respData struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + + // unmarshal the response into a struct so we can check for errors + var d respData + err = json.Unmarshal(data, &d) + if err != nil { + log.Fatalln(err) + } + + // if the response has an error json message then it should be considered an error just like any http error + if len(d.Error.Message) > 0 { + // truncate the error message if it is too long + if len(d.Error.Message) > 1000 { + d.Error.Message = d.Error.Message[:1000] + "..." + } + // remove newlines from the error message so we don't screw up the report + d.Error.Message = strings.ReplaceAll(d.Error.Message, "\n", "") + + err = fmt.Errorf("JSON error: code:%d, message:%s", d.Error.Code, d.Error.Message) + } + + if rpc.printResp { + fmt.Printf("[%s] %s", rpc.method, string(data)) + } + + resp.Body.Close() //nolint:errcheck + } + + rpc.results <- &result{ + statusCode: statusCode, + err: err, + duration: time.Since(start), + } + } +} + +func (rpc *RPCMethod) buildRequest() (*http.Request, error) { + jreq, err := json.Marshal(struct { + Jsonrpc string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + }{ + Jsonrpc: "2.0", + Method: rpc.method, + Params: json.RawMessage(rpc.params), + ID: 0, + }) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", rpc.uri, bytes.NewReader(jreq)) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + + return req, nil +} + +func (rpc *RPCMethod) Stop() { + for i := 0; i < rpc.concurrency; i++ { + rpc.stopCh <- struct{}{} + } +} + +func (rpc *RPCMethod) Report() { + total := time.Since(rpc.start) + fmt.Fprintf(rpc.w, "[%s]:\n", rpc.method) + fmt.Fprintf(rpc.w, "- Options:\n") + fmt.Fprintf(rpc.w, " - concurrency: %d\n", rpc.concurrency) + fmt.Fprintf(rpc.w, " - params: %s\n", rpc.params) + fmt.Fprintf(rpc.w, " - qps: %d\n", rpc.qps) + rpc.reporter.Print(total, rpc.w) +} + +// Reporter reads the results from the workers through the results channel and aggregates the results. +type Reporter struct { + // write the report to this writer + w io.Writer + // the reporter read the results from this channel + results chan *result + // doneCh is used to signal that the reporter has finished reading the results (channel has closed) + doneCh chan bool + + // lock protect the following fields during critical sections (if --watch was specified) + lock sync.Mutex + // the latencies of all requests + latencies []int64 + // the number of requests that returned each status code + statusCodes map[int]int + // the number of errors that occurred + errors map[string]int +} + +func NewReporter(results chan *result, w io.Writer) *Reporter { + return &Reporter{ + w: w, + results: results, + doneCh: make(chan bool, 1), + statusCodes: make(map[int]int), + errors: make(map[string]int), + } +} + +func (r *Reporter) Run() { + for res := range r.results { + r.lock.Lock() + + r.latencies = append(r.latencies, res.duration.Milliseconds()) + + if res.statusCode != nil { + r.statusCodes[*res.statusCode]++ + } + + if res.err != nil { + if len(r.errors) < 1_000_000 { + r.errors[res.err.Error()]++ + } else { + // we don't want to store too many errors in memory + r.errors["hidden"]++ + } + } else { + r.errors["nil"]++ + } + + r.lock.Unlock() + } + + r.doneCh <- true +} + +func (r *Reporter) Print(elapsed time.Duration, w io.Writer) { + r.lock.Lock() + defer r.lock.Unlock() + + nrReq := int64(len(r.latencies)) + if nrReq == 0 { + fmt.Println("No requests were made") + return + } + + // we need to sort the latencies slice to calculate the percentiles + sort.Slice(r.latencies, func(i, j int) bool { + return r.latencies[i] < r.latencies[j] + }) + + var totalLatency int64 = 0 + for _, latency := range r.latencies { + totalLatency += latency + } + + fmt.Fprintf(w, "- Total Requests: %d\n", nrReq) + fmt.Fprintf(w, "- Total Duration: %dms\n", elapsed.Milliseconds()) + fmt.Fprintf(w, "- Requests/sec: %f\n", float64(nrReq)/elapsed.Seconds()) + fmt.Fprintf(w, "- Avg latency: %dms\n", totalLatency/nrReq) + fmt.Fprintf(w, "- Median latency: %dms\n", r.latencies[nrReq/2]) + fmt.Fprintf(w, "- Latency distribution:\n") + percentiles := []float64{0.1, 0.5, 0.9, 0.95, 0.99, 0.999} + for _, p := range percentiles { + idx := int64(p * float64(nrReq)) + fmt.Fprintf(w, " %s%% in %dms\n", fmt.Sprintf("%.2f", p*100.0), r.latencies[idx]) + } + + // create a simple histogram with 10 buckets spanning the range of latency + // into equal ranges + // + nrBucket := 10 + buckets := make([]Bucket, nrBucket) + latencyRange := r.latencies[len(r.latencies)-1] + bucketRange := latencyRange / int64(nrBucket) + + // mark the end of each bucket + for i := 0; i < nrBucket; i++ { + buckets[i].start = int64(i) * bucketRange + buckets[i].end = buckets[i].start + bucketRange + // extend the last bucked by any remaning range caused by the integer division + if i == nrBucket-1 { + buckets[i].end = latencyRange + } + } + + // count the number of requests in each bucket + currBucket := 0 + for i := 0; i < len(r.latencies); { + if r.latencies[i] <= buckets[currBucket].end { + buckets[currBucket].cnt++ + i++ + } else { + currBucket++ + } + } + + // print the histogram using a tabwriter which will align the columns nicely + fmt.Fprintf(w, "- Histogram:\n") + const padding = 2 + tabWriter := tabwriter.NewWriter(w, 0, 0, padding, ' ', tabwriter.AlignRight|tabwriter.Debug) + for i := 0; i < nrBucket; i++ { + ratio := float64(buckets[i].cnt) / float64(nrReq) + bars := strings.Repeat("#", int(ratio*100)) + fmt.Fprintf(tabWriter, " %d-%dms\t%d\t%s (%s%%)\n", buckets[i].start, buckets[i].end, buckets[i].cnt, bars, fmt.Sprintf("%.2f", ratio*100)) + } + tabWriter.Flush() //nolint:errcheck + + fmt.Fprintf(w, "- Status codes:\n") + for code, cnt := range r.statusCodes { + fmt.Fprintf(w, " [%d]: %d\n", code, cnt) + } + + // print the 10 most occurring errors (in case error values are not unique) + // + type kv struct { + err string + cnt int + } + var sortedErrors []kv + for err, cnt := range r.errors { + sortedErrors = append(sortedErrors, kv{err, cnt}) + } + sort.Slice(sortedErrors, func(i, j int) bool { + return sortedErrors[i].cnt > sortedErrors[j].cnt + }) + fmt.Fprintf(w, "- Errors (top 10):\n") + for i, se := range sortedErrors { + if i > 10 { + break + } + fmt.Fprintf(w, " [%s]: %d\n", se.err, se.cnt) + } +} + +type Bucket struct { + start int64 + // the end value of the bucket + end int64 + // how many entries are in the bucket + cnt int +}