Address code review comments

This commit is contained in:
Menghan Li
2016-04-15 11:50:25 -07:00
parent 9530d84aba
commit 8c5cde66aa
2 changed files with 62 additions and 96 deletions

View File

@ -69,16 +69,16 @@ type testCaseWithWeight struct {
} }
// parseTestCases converts test case string to a list of struct testCaseWithWeight. // parseTestCases converts test case string to a list of struct testCaseWithWeight.
func parseTestCases(testCaseString string) ([]testCaseWithWeight, error) { func parseTestCases(testCaseString string) []testCaseWithWeight {
testCaseStrings := strings.Split(testCaseString, ",") testCaseStrings := strings.Split(testCaseString, ",")
testCases := make([]testCaseWithWeight, len(testCaseStrings)) testCases := make([]testCaseWithWeight, len(testCaseStrings))
for i, str := range testCaseStrings { for i, str := range testCaseStrings {
temp := strings.Split(str, ":") testCase := strings.Split(str, ":")
if len(temp) < 2 { if len(testCase) != 2 {
return nil, fmt.Errorf("invalid test case with weight: %s", str) panic(fmt.Sprintf("invalid test case with weight: %s", str))
} }
// Check if test case is supported. // Check if test case is supported.
switch temp[0] { switch testCase[0] {
case case
"empty_unary", "empty_unary",
"large_unary", "large_unary",
@ -86,16 +86,16 @@ func parseTestCases(testCaseString string) ([]testCaseWithWeight, error) {
"server_streaming", "server_streaming",
"empty_stream": "empty_stream":
default: default:
return nil, fmt.Errorf("unknown test type: %s", temp[0]) panic(fmt.Sprintf("unknown test type: %s", testCase[0]))
} }
testCases[i].name = temp[0] testCases[i].name = testCase[0]
w, err := strconv.Atoi(temp[1]) w, err := strconv.Atoi(testCase[1])
if err != nil { if err != nil {
return nil, fmt.Errorf("%v", err) panic(fmt.Sprintf("%v", err))
} }
testCases[i].weight = w testCases[i].weight = w
} }
return testCases, nil return testCases
} }
// weightedRandomTestSelector defines a weighted random selector for test case types. // weightedRandomTestSelector defines a weighted random selector for test case types.
@ -114,25 +114,24 @@ func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTe
return &weightedRandomTestSelector{tests, totalWeight} return &weightedRandomTestSelector{tests, totalWeight}
} }
func (selector weightedRandomTestSelector) getNextTest() (string, error) { func (selector weightedRandomTestSelector) getNextTest() string {
random := rand.Intn(selector.totalWeight) random := rand.Intn(selector.totalWeight)
var weightSofar int var weightSofar int
for _, test := range selector.tests { for _, test := range selector.tests {
weightSofar += test.weight weightSofar += test.weight
if random < weightSofar { if random < weightSofar {
return test.name, nil return test.name
} }
} }
return "", fmt.Errorf("no test case selected by weightedRandomTestSelector") panic("no test case selected by weightedRandomTestSelector")
} }
// gauge defines type for gauge. // gauge stores the qps of one interop client (one stub).
type gauge struct { type gauge struct {
mutex sync.RWMutex mutex sync.RWMutex
val int64 val int64
} }
// Set updates the gauge value
func (g *gauge) set(v int64) { func (g *gauge) set(v int64) {
g.mutex.Lock() g.mutex.Lock()
defer g.mutex.Unlock() defer g.mutex.Unlock()
@ -145,9 +144,10 @@ func (g *gauge) get() int64 {
return g.val return g.val
} }
// Server implements metrics server functions. // server implements metrics server functions.
type server struct { type server struct {
mutex sync.RWMutex mutex sync.RWMutex
// gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
gauges map[string]*gauge gauges map[string]*gauge
} }
@ -162,8 +162,7 @@ func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.Metri
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
for name, gauge := range s.gauges { for name, gauge := range s.gauges {
err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}) if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}); err != nil {
if err != nil {
return err return err
} }
} }
@ -181,19 +180,18 @@ func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*met
return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
} }
// CreateGauge creates a guage using the given name in metrics server // createGauge creates a guage using the given name in metrics server.
func (s *server) createGauge(name string) (*gauge, error) { func (s *server) createGauge(name string) *gauge {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
grpclog.Printf("create gauge: %s", name)
if _, ok := s.gauges[name]; ok { if _, ok := s.gauges[name]; ok {
// gauge already exists. // gauge already exists.
return nil, fmt.Errorf("gauge %s already exists", name) panic(fmt.Sprintf("gauge %s already exists", name))
} }
var g gauge var g gauge
s.gauges[name] = &g s.gauges[name] = &g
return &g, nil return &g
} }
func startServer(server *server, port int) { func startServer(server *server, port int) {
@ -208,57 +206,35 @@ func startServer(server *server, port int) {
} }
// stressClient defines client for stress test. // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
type stressClient struct { func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
testID int
address string
selector *weightedRandomTestSelector
interopClient testpb.TestServiceClient
stop <-chan bool
}
// newStressClient construct a new stressClient.
func newStressClient(id int, addr string, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) *stressClient {
client := testpb.NewTestServiceClient(conn) client := testpb.NewTestServiceClient(conn)
return &stressClient{testID: id, address: addr, selector: selector, interopClient: client, stop: stop}
}
// mainLoop uses weightedRandomTestSelector to select test case and runs the tests.
func (c *stressClient) mainLoop(gauge *gauge) {
var numCalls int64 var numCalls int64
timeStarted := time.Now() startTime := time.Now()
for { for {
done := make(chan bool) done := make(chan bool, 1)
go func() { go func() {
test, err := c.selector.getNextTest() test := selector.getNextTest()
if err != nil {
grpclog.Printf("%v", err)
done <- false
}
switch test { switch test {
case "empty_unary": case "empty_unary":
interop.DoEmptyUnaryCall(c.interopClient) interop.DoEmptyUnaryCall(client)
case "large_unary": case "large_unary":
interop.DoLargeUnaryCall(c.interopClient) interop.DoLargeUnaryCall(client)
case "client_streaming": case "client_streaming":
interop.DoClientStreaming(c.interopClient) interop.DoClientStreaming(client)
case "server_streaming": case "server_streaming":
interop.DoServerStreaming(c.interopClient) interop.DoServerStreaming(client)
case "empty_stream": case "empty_stream":
interop.DoEmptyStream(c.interopClient) interop.DoEmptyStream(client)
default:
grpclog.Fatalf("Unsupported test case: %d", test)
} }
done <- true done <- true
}() }()
select { select {
case <-c.stop: case <-stop:
return return
case r := <-done: case <-done:
if r {
numCalls++ numCalls++
gauge.set(int64(float64(numCalls) / time.Since(timeStarted).Seconds())) gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
}
} }
} }
} }
@ -284,10 +260,7 @@ func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
func main() { func main() {
flag.Parse() flag.Parse()
addresses := strings.Split(*serverAddresses, ",") addresses := strings.Split(*serverAddresses, ",")
tests, err := parseTestCases(*testCases) tests := parseTestCases(*testCases)
if err != nil {
grpclog.Fatalf("%v\n", err)
}
logParameterInfo(addresses, tests) logParameterInfo(addresses, tests)
testSelector := newWeightedRandomTestSelector(tests) testSelector := newWeightedRandomTestSelector(tests)
metricsServer := newMetricsServer() metricsServer := newMetricsServer()
@ -296,7 +269,6 @@ func main() {
wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel) wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
stop := make(chan bool) stop := make(chan bool)
var clientIndex int
for serverIndex, address := range addresses { for serverIndex, address := range addresses {
for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ { for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
conn, err := grpc.Dial(address, grpc.WithInsecure()) conn, err := grpc.Dial(address, grpc.WithInsecure())
@ -304,17 +276,12 @@ func main() {
grpclog.Fatalf("Fail to dial: %v", err) grpclog.Fatalf("Fail to dial: %v", err)
} }
defer conn.Close() defer conn.Close()
for stubIndex := 0; stubIndex < *numStubsPerChannel; stubIndex++ { for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
clientIndex++ name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
client := newStressClient(clientIndex, address, conn, testSelector, stop)
buf := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, stubIndex+1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if g, err := metricsServer.createGauge(buf); err != nil { g := metricsServer.createGauge(name)
grpclog.Fatalf("%v", err) performRPCs(g, conn, testSelector, stop)
} else {
client.mainLoop(g)
}
}() }()
} }

View File

@ -1,6 +1,6 @@
/* /*
* *
* Copyright 2014, Google Inc. * Copyright 2016, Google Inc.
* All rights reserved. * All rights reserved.
* *
* Redistribution and use in source and binary forms, with or without * Redistribution and use in source and binary forms, with or without
@ -35,8 +35,8 @@ package main
import ( import (
"flag" "flag"
"fmt"
"io" "io"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -45,54 +45,53 @@ import (
) )
var ( var (
metricsServerAddressPtr = flag.String("metrics_server_address", "", "The metrics server addresses in the fomrat <hostname>:<port>") metricsServerAddress = flag.String("metrics_server_address", "", "The metrics server addresses in the fomrat <hostname>:<port>")
totalOnlyPtr = flag.Bool("total_only", false, "If true, this prints only the total value of all gauges") totalOnly = flag.Bool("total_only", false, "If true, this prints only the total value of all gauges")
) )
const timeoutSeconds = 10
func printMetrics(client metricspb.MetricsServiceClient, totalOnly bool) { func printMetrics(client metricspb.MetricsServiceClient, totalOnly bool) {
ctx, _ := context.WithTimeout(context.Background(), timeoutSeconds*time.Second) stream, err := client.GetAllGauges(context.Background(), &metricspb.EmptyMessage{})
stream, err := client.GetAllGauges(ctx, &metricspb.EmptyMessage{})
if err != nil { if err != nil {
grpclog.Fatalf("failed to call GetAllGuages: %v", err) grpclog.Fatalf("failed to call GetAllGuages: %v", err)
} }
var overallQPS int64 var (
var rpcStatus error overallQPS int64
rpcStatus error
)
for { for {
gaugeResponse, err := stream.Recv() gaugeResponse, err := stream.Recv()
if err != nil { if err != nil {
rpcStatus = err rpcStatus = err
break break
} }
if _, ok := gaugeResponse.GetValue().(*metricspb.GaugeResponse_LongValue); ok { if _, ok := gaugeResponse.GetValue().(*metricspb.GaugeResponse_LongValue); !ok {
panic(fmt.Sprintf("gauge %s is not a long value", gaugeResponse.Name))
}
v := gaugeResponse.GetLongValue()
if !totalOnly { if !totalOnly {
grpclog.Printf("%s: %d", gaugeResponse.Name, gaugeResponse.GetLongValue()) grpclog.Printf("%s: %d", gaugeResponse.Name, v)
} }
overallQPS += gaugeResponse.GetLongValue() overallQPS += v
} else {
grpclog.Printf("gauge %s is not a long value", gaugeResponse.Name)
} }
}
grpclog.Printf("overall qps: %d", overallQPS)
if rpcStatus != io.EOF { if rpcStatus != io.EOF {
grpclog.Fatalf("failed to finish server streaming: %v", rpcStatus) grpclog.Fatalf("failed to finish server streaming: %v", rpcStatus)
} }
grpclog.Printf("overall qps: %d", overallQPS)
} }
func main() { func main() {
flag.Parse() flag.Parse()
if len(*metricsServerAddressPtr) == 0 { if *metricsServerAddress == "" {
grpclog.Fatalf("Cannot connect to the Metrics server. Please pass the address of the metrics server to connect to via the 'metrics_server_address' flag") grpclog.Fatalf("Metrics server address is empty.")
} }
conn, err := grpc.Dial(*metricsServerAddressPtr, grpc.WithInsecure()) conn, err := grpc.Dial(*metricsServerAddress, grpc.WithInsecure())
if err != nil { if err != nil {
grpclog.Fatalf("cannot connect to metrics server: %v", err) grpclog.Fatalf("cannot connect to metrics server: %v", err)
} }
defer conn.Close() defer conn.Close()
c := metricspb.NewMetricsServiceClient(conn) c := metricspb.NewMetricsServiceClient(conn)
printMetrics(c, *totalOnlyPtr) printMetrics(c, *totalOnly)
} }