diff --git a/stress/client/main.go b/stress/client/main.go index 0568c459..bb665e98 100644 --- a/stress/client/main.go +++ b/stress/client/main.go @@ -69,16 +69,16 @@ type testCaseWithWeight struct { } // parseTestCases converts test case string to a list of struct testCaseWithWeight. -func parseTestCases(testCaseString string) ([]testCaseWithWeight, error) { +func parseTestCases(testCaseString string) []testCaseWithWeight { testCaseStrings := strings.Split(testCaseString, ",") testCases := make([]testCaseWithWeight, len(testCaseStrings)) for i, str := range testCaseStrings { - temp := strings.Split(str, ":") - if len(temp) < 2 { - return nil, fmt.Errorf("invalid test case with weight: %s", str) + testCase := strings.Split(str, ":") + if len(testCase) != 2 { + panic(fmt.Sprintf("invalid test case with weight: %s", str)) } // Check if test case is supported. - switch temp[0] { + switch testCase[0] { case "empty_unary", "large_unary", @@ -86,16 +86,16 @@ func parseTestCases(testCaseString string) ([]testCaseWithWeight, error) { "server_streaming", "empty_stream": default: - return nil, fmt.Errorf("unknown test type: %s", temp[0]) + panic(fmt.Sprintf("unknown test type: %s", testCase[0])) } - testCases[i].name = temp[0] - w, err := strconv.Atoi(temp[1]) + testCases[i].name = testCase[0] + w, err := strconv.Atoi(testCase[1]) if err != nil { - return nil, fmt.Errorf("%v", err) + panic(fmt.Sprintf("%v", err)) } testCases[i].weight = w } - return testCases, nil + return testCases } // weightedRandomTestSelector defines a weighted random selector for test case types. @@ -114,25 +114,24 @@ func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTe return &weightedRandomTestSelector{tests, totalWeight} } -func (selector weightedRandomTestSelector) getNextTest() (string, error) { +func (selector weightedRandomTestSelector) getNextTest() string { random := rand.Intn(selector.totalWeight) var weightSofar int for _, test := range selector.tests { weightSofar += test.weight if random < weightSofar { - return test.name, nil + return test.name } } - return "", fmt.Errorf("no test case selected by weightedRandomTestSelector") + panic("no test case selected by weightedRandomTestSelector") } -// gauge defines type for gauge. +// gauge stores the qps of one interop client (one stub). type gauge struct { mutex sync.RWMutex val int64 } -// Set updates the gauge value func (g *gauge) set(v int64) { g.mutex.Lock() defer g.mutex.Unlock() @@ -145,9 +144,10 @@ func (g *gauge) get() int64 { return g.val } -// Server implements metrics server functions. +// server implements metrics server functions. type server struct { - mutex sync.RWMutex + mutex sync.RWMutex + // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge. gauges map[string]*gauge } @@ -162,8 +162,7 @@ func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.Metri defer s.mutex.RUnlock() for name, gauge := range s.gauges { - err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}) - if err != nil { + if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}); err != nil { return err } } @@ -181,19 +180,18 @@ func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*met return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) } -// CreateGauge creates a guage using the given name in metrics server -func (s *server) createGauge(name string) (*gauge, error) { +// createGauge creates a guage using the given name in metrics server. +func (s *server) createGauge(name string) *gauge { s.mutex.Lock() defer s.mutex.Unlock() - grpclog.Printf("create gauge: %s", name) if _, ok := s.gauges[name]; ok { // gauge already exists. - return nil, fmt.Errorf("gauge %s already exists", name) + panic(fmt.Sprintf("gauge %s already exists", name)) } var g gauge s.gauges[name] = &g - return &g, nil + return &g } func startServer(server *server, port int) { @@ -208,57 +206,35 @@ func startServer(server *server, port int) { } -// stressClient defines client for stress test. -type stressClient struct { - testID int - address string - selector *weightedRandomTestSelector - interopClient testpb.TestServiceClient - stop <-chan bool -} - -// newStressClient construct a new stressClient. -func newStressClient(id int, addr string, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) *stressClient { +// performRPCs uses weightedRandomTestSelector to select test case and runs the tests. +func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) { client := testpb.NewTestServiceClient(conn) - return &stressClient{testID: id, address: addr, selector: selector, interopClient: client, stop: stop} -} - -// mainLoop uses weightedRandomTestSelector to select test case and runs the tests. -func (c *stressClient) mainLoop(gauge *gauge) { var numCalls int64 - timeStarted := time.Now() + startTime := time.Now() for { - done := make(chan bool) + done := make(chan bool, 1) go func() { - test, err := c.selector.getNextTest() - if err != nil { - grpclog.Printf("%v", err) - done <- false - } + test := selector.getNextTest() switch test { case "empty_unary": - interop.DoEmptyUnaryCall(c.interopClient) + interop.DoEmptyUnaryCall(client) case "large_unary": - interop.DoLargeUnaryCall(c.interopClient) + interop.DoLargeUnaryCall(client) case "client_streaming": - interop.DoClientStreaming(c.interopClient) + interop.DoClientStreaming(client) case "server_streaming": - interop.DoServerStreaming(c.interopClient) + interop.DoServerStreaming(client) case "empty_stream": - interop.DoEmptyStream(c.interopClient) - default: - grpclog.Fatalf("Unsupported test case: %d", test) + interop.DoEmptyStream(client) } done <- true }() select { - case <-c.stop: + case <-stop: return - case r := <-done: - if r { - numCalls++ - gauge.set(int64(float64(numCalls) / time.Since(timeStarted).Seconds())) - } + case <-done: + numCalls++ + gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds())) } } } @@ -284,10 +260,7 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) { func main() { flag.Parse() addresses := strings.Split(*serverAddresses, ",") - tests, err := parseTestCases(*testCases) - if err != nil { - grpclog.Fatalf("%v\n", err) - } + tests := parseTestCases(*testCases) logParameterInfo(addresses, tests) testSelector := newWeightedRandomTestSelector(tests) metricsServer := newMetricsServer() @@ -296,7 +269,6 @@ func main() { wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) stop := make(chan bool) - var clientIndex int for serverIndex, address := range addresses { for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { conn, err := grpc.Dial(address, grpc.WithInsecure()) @@ -304,17 +276,12 @@ func main() { grpclog.Fatalf("Fail to dial: %v", err) } defer conn.Close() - for stubIndex := 0; stubIndex < *numStubsPerChannel; stubIndex++ { - clientIndex++ - client := newStressClient(clientIndex, address, conn, testSelector, stop) - buf := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, stubIndex+1) + for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ { + name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1) go func() { defer wg.Done() - if g, err := metricsServer.createGauge(buf); err != nil { - grpclog.Fatalf("%v", err) - } else { - client.mainLoop(g) - } + g := metricsServer.createGauge(name) + performRPCs(g, conn, testSelector, stop) }() } diff --git a/stress/metrics_client/main.go b/stress/metrics_client/main.go index fbce662c..983a8ff2 100644 --- a/stress/metrics_client/main.go +++ b/stress/metrics_client/main.go @@ -1,6 +1,6 @@ /* * - * Copyright 2014, Google Inc. + * Copyright 2016, Google Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -35,8 +35,8 @@ package main import ( "flag" + "fmt" "io" - "time" "golang.org/x/net/context" "google.golang.org/grpc" @@ -45,54 +45,53 @@ import ( ) var ( - metricsServerAddressPtr = flag.String("metrics_server_address", "", "The metrics server addresses in the fomrat <hostname>:<port>") - totalOnlyPtr = flag.Bool("total_only", false, "If true, this prints only the total value of all gauges") + metricsServerAddress = flag.String("metrics_server_address", "", "The metrics server addresses in the fomrat <hostname>:<port>") + totalOnly = flag.Bool("total_only", false, "If true, this prints only the total value of all gauges") ) -const timeoutSeconds = 10 - func printMetrics(client metricspb.MetricsServiceClient, totalOnly bool) { - ctx, _ := context.WithTimeout(context.Background(), timeoutSeconds*time.Second) - stream, err := client.GetAllGauges(ctx, &metricspb.EmptyMessage{}) + stream, err := client.GetAllGauges(context.Background(), &metricspb.EmptyMessage{}) if err != nil { grpclog.Fatalf("failed to call GetAllGuages: %v", err) } - var overallQPS int64 - var rpcStatus error + var ( + overallQPS int64 + rpcStatus error + ) for { gaugeResponse, err := stream.Recv() if err != nil { rpcStatus = err break } - if _, ok := gaugeResponse.GetValue().(*metricspb.GaugeResponse_LongValue); ok { - if !totalOnly { - grpclog.Printf("%s: %d", gaugeResponse.Name, gaugeResponse.GetLongValue()) - } - overallQPS += gaugeResponse.GetLongValue() - } else { - grpclog.Printf("gauge %s is not a long value", gaugeResponse.Name) + if _, ok := gaugeResponse.GetValue().(*metricspb.GaugeResponse_LongValue); !ok { + panic(fmt.Sprintf("gauge %s is not a long value", gaugeResponse.Name)) } + v := gaugeResponse.GetLongValue() + if !totalOnly { + grpclog.Printf("%s: %d", gaugeResponse.Name, v) + } + overallQPS += v } - grpclog.Printf("overall qps: %d", overallQPS) if rpcStatus != io.EOF { grpclog.Fatalf("failed to finish server streaming: %v", rpcStatus) } + grpclog.Printf("overall qps: %d", overallQPS) } func main() { flag.Parse() - if len(*metricsServerAddressPtr) == 0 { - grpclog.Fatalf("Cannot connect to the Metrics server. Please pass the address of the metrics server to connect to via the 'metrics_server_address' flag") + if *metricsServerAddress == "" { + grpclog.Fatalf("Metrics server address is empty.") } - conn, err := grpc.Dial(*metricsServerAddressPtr, grpc.WithInsecure()) + conn, err := grpc.Dial(*metricsServerAddress, grpc.WithInsecure()) if err != nil { grpclog.Fatalf("cannot connect to metrics server: %v", err) } defer conn.Close() c := metricspb.NewMetricsServiceClient(conn) - printMetrics(c, *totalOnlyPtr) + printMetrics(c, *totalOnly) }