From 70f7fa1c19d84827b882196eebb61dcca446b734 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Fri, 8 Apr 2016 11:33:25 -0700 Subject: [PATCH] Address code review comments --- stress/client/main.go | 134 +++++++++++++++--------------------------- 1 file changed, 46 insertions(+), 88 deletions(-) diff --git a/stress/client/main.go b/stress/client/main.go index 11ccbc0e..916d7f8b 100644 --- a/stress/client/main.go +++ b/stress/client/main.go @@ -1,3 +1,4 @@ +// client starts an interop client to do stress test and a metrics server to report qps. package main import ( @@ -20,70 +21,18 @@ import ( ) var ( - serverAddressesPtr = flag.String("server_addresses", "localhost:8080", "a list of server addresses") - testCasesPtr = flag.String("test_cases", "", "a list of test cases along with the relative weights") - testDurationSecsPtr = flag.Int("test_duration_secs", -1, "test duration in seconds") - numChannelsPerServerPtr = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") - numStubsPerChannelPtr = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") - metricsPortPtr = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") + serverAddresses = flag.String("server_addresses", "localhost:8080", "a list of server addresses") + testCases = flag.String("test_cases", "", "a list of test cases along with the relative weights") + testDurationSecs = flag.Int("test_duration_secs", -1, "test duration in seconds") + numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server") + numStubsPerChannel = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server") + metricsPort = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics") ) -// testCaseType is the type of test to be run -type testCaseType uint32 - -const ( - // emptyUnary is to make a unary RPC with empty request and response - emptyUnary testCaseType = 0 - - // largeUnary is to make a unary RPC with large payload in the request and response - largeUnary testCaseType = 1 - - // TODO largeCompressedUnary - - // clientStreaming is to make a client streaming RPC - clientStreaming testCaseType = 3 - - // serverStreaming is to make a server streaming RPC - serverStreaming testCaseType = 4 - - // emptyStream is to make a bi-directional streaming with zero message - emptyStream testCaseType = 5 - - // unknownTest means something is wrong - unknownTest testCaseType = 6 -) - -var testCaseNameTypeMap = map[string]testCaseType{ - "empty_unary": emptyUnary, - "large_unary": largeUnary, - "client_streaming": clientStreaming, - "server_streaming": serverStreaming, - "empty_stream": emptyStream, -} - -var testCaseTypeNameMap = map[testCaseType]string{ - emptyUnary: "empty_unary", - largeUnary: "large_unary", - clientStreaming: "client_streaming", - serverStreaming: "server_streaming", - emptyStream: "empty_stream", -} - -func (t testCaseType) String() string { - if s, ok := testCaseTypeNameMap[t]; ok { - return s - } - return "" -} - // testCaseWithWeight contains the test case type and its weight. type testCaseWithWeight struct { - testCase testCaseType - weight int -} - -func (test testCaseWithWeight) String() string { - return fmt.Sprintf("testCaseType: %d, Weight: %d", test.testCase, test.weight) + name string + weight int } // parseTestCases converts test case string to a list of struct testCaseWithWeight. @@ -95,11 +44,18 @@ func parseTestCases(testCaseString string) ([]testCaseWithWeight, error) { if len(temp) < 2 { return nil, fmt.Errorf("invalid test case with weight: %s", str) } - t, ok := testCaseNameTypeMap[temp[0]] - if !ok { + // Check if test case is supported. + switch temp[0] { + case + "empty_unary", + "large_unary", + "client_streaming", + "server_streaming", + "empty_stream": + default: return nil, fmt.Errorf("unknown test type: %s", temp[0]) } - testCases[i].testCase = t + testCases[i].name = temp[0] w, err := strconv.Atoi(temp[1]) if err != nil { return nil, fmt.Errorf("%v", err) @@ -125,16 +81,16 @@ func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTe return &weightedRandomTestSelector{tests, totalWeight} } -func (selector weightedRandomTestSelector) getNextTest() (testCaseType, error) { +func (selector weightedRandomTestSelector) getNextTest() (string, error) { random := rand.Intn(selector.totalWeight) var weightSofar int for _, test := range selector.tests { weightSofar += test.weight if random < weightSofar { - return test.testCase, nil + return test.name, nil } } - return unknownTest, fmt.Errorf("no test case selected by weightedRandomTestSelector") + return "", fmt.Errorf("no test case selected by weightedRandomTestSelector") } // gauge defines type for gauge. @@ -245,18 +201,18 @@ func (c *stressClient) mainLoop(gauge *gauge) { continue } switch test { - case emptyUnary: + case "empty_unary": interop.DoEmptyUnaryCall(c.interopClient) - case largeUnary: + case "large_unary": interop.DoLargeUnaryCall(c.interopClient) - case clientStreaming: + case "client_streaming": interop.DoClientStreaming(c.interopClient) - case serverStreaming: + case "server_streaming": interop.DoServerStreaming(c.interopClient) - case emptyStream: + case "empty_stream": interop.DoEmptyStream(c.interopClient) default: - grpclog.Fatal("Unsupported test case: %d", test) + grpclog.Fatalf("Unsupported test case: %d", test) } numCalls++ gauge.set(int64(float64(numCalls) / time.Since(timeStarted).Seconds())) @@ -264,16 +220,18 @@ func (c *stressClient) mainLoop(gauge *gauge) { } func logParameterInfo(addresses []string, tests []testCaseWithWeight) { - grpclog.Printf("server_addresses: %s", *serverAddressesPtr) - grpclog.Printf("test_cases: %s", *testCasesPtr) - grpclog.Printf("test_duration-secs: %d", *testDurationSecsPtr) - grpclog.Printf("num_channels_per_server: %d", *numChannelsPerServerPtr) - grpclog.Printf("num_stubs_per_channel: %d", *numStubsPerChannelPtr) - grpclog.Printf("metrics_port: %d", *metricsPortPtr) + grpclog.Printf("server_addresses: %s", *serverAddresses) + grpclog.Printf("test_cases: %s", *testCases) + grpclog.Printf("test_duration-secs: %d", *testDurationSecs) + grpclog.Printf("num_channels_per_server: %d", *numChannelsPerServer) + grpclog.Printf("num_stubs_per_channel: %d", *numStubsPerChannel) + grpclog.Printf("metrics_port: %d", *metricsPort) + grpclog.Println("addresses:") for i, addr := range addresses { grpclog.Printf("%d. %s\n", i+1, addr) } + grpclog.Println("tests:") for i, test := range tests { grpclog.Printf("%d. %v\n", i+1, test) } @@ -281,28 +239,28 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) { func main() { flag.Parse() - serverAddresses := strings.Split(*serverAddressesPtr, ",") - testCases, err := parseTestCases(*testCasesPtr) + addresses := strings.Split(*serverAddresses, ",") + tests, err := parseTestCases(*testCases) if err != nil { grpclog.Fatalf("%v\n", err) } - logParameterInfo(serverAddresses, testCases) - testSelector := newWeightedRandomTestSelector(testCases) + logParameterInfo(addresses, tests) + testSelector := newWeightedRandomTestSelector(tests) metricsServer := newMetricsServer() var wg sync.WaitGroup - wg.Add(len(serverAddresses) * *numChannelsPerServerPtr * *numStubsPerChannelPtr) + wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) var clientIndex int - for serverIndex, address := range serverAddresses { - for connIndex := 0; connIndex < *numChannelsPerServerPtr; connIndex++ { + for serverIndex, address := range addresses { + for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { conn, err := grpc.Dial(address, grpc.WithInsecure()) if err != nil { grpclog.Fatalf("Fail to dial: %v", err) } defer conn.Close() - for stubIndex := 0; stubIndex < *numStubsPerChannelPtr; stubIndex++ { + for stubIndex := 0; stubIndex < *numStubsPerChannel; stubIndex++ { clientIndex++ - client := newStressClient(clientIndex, address, conn, testSelector, *testDurationSecsPtr) + client := newStressClient(clientIndex, address, conn, testSelector, *testDurationSecs) buf := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, stubIndex+1) go func() { defer wg.Done() @@ -316,7 +274,7 @@ func main() { } } - go startServer(metricsServer, *metricsPortPtr) + go startServer(metricsServer, *metricsPort) wg.Wait() grpclog.Printf(" ===== ALL DONE ===== ")