From 25c36604b9792bc619628183aaaf5fab5557677f Mon Sep 17 00:00:00 2001
From: yangzhouhan <yangzhouhan@gmail.com>
Date: Fri, 14 Aug 2015 16:22:19 -0700
Subject: [PATCH] add oauth2 and perrpc interop tests

---
 credentials/oauth/oauth.go | 15 +++++++
 interop/client/client.go   | 90 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 102 insertions(+), 3 deletions(-)

diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go
index 28ed0c39..e213008e 100644
--- a/credentials/oauth/oauth.go
+++ b/credentials/oauth/oauth.go
@@ -91,6 +91,21 @@ func (j jwtAccess) GetRequestMetadata(ctx context.Context) (map[string]string, e
 	}, nil
 }
 
+// oauthAccess supplies credentials from a given token.
+type oauthAccess struct {
+	token oauth2.Token
+}
+
+func NewOauthAccess(token *oauth2.Token) credentials.Credentials {
+	return oauthAccess{token: *token}
+}
+
+func (oa oauthAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
+	return map[string]string{
+		"authorization": oa.token.TokenType + " " + oa.token.AccessToken,
+	}, nil
+}
+
 // NewComputeEngine constructs the credentials that fetches access tokens from
 // Google Compute Engine (GCE)'s metadata server. It is only valid to use this
 // if your program is running on a GCE instance.
diff --git a/interop/client/client.go b/interop/client/client.go
index 048ce000..c4c3cf73 100644
--- a/interop/client/client.go
+++ b/interop/client/client.go
@@ -44,6 +44,8 @@ import (
 
 	"github.com/golang/protobuf/proto"
 	"golang.org/x/net/context"
+	"golang.org/x/oauth2"
+	"golang.org/x/oauth2/google"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
@@ -73,9 +75,11 @@ var (
         timeout_on_sleeping_server: fullduplex streaming;
         compute_engine_creds: large_unary with compute engine auth;
         service_account_creds: large_unary with service account auth;
-	jwt_token_creds: large_unary with jwt token auth;
-	cancel_after_begin: cancellation after metadata has been sent but before payloads are sent;
-	cancel_after_first_response: cancellation after receiving 1st message from the server.`)
+        jwt_token_creds: large_unary with jwt token auth;
+        per_rpc_creds: large_unary with per rpc token;
+        oauth2_token_creds: large_unary with oauth2 token auth;
+        cancel_after_begin: cancellation after metadata has been sent but before payloads are sent;
+        cancel_after_first_response: cancellation after receiving 1st message from the server.`)
 )
 
 var (
@@ -364,6 +368,72 @@ func doJWTTokenCreds(tc testpb.TestServiceClient) {
 	grpclog.Println("JWTtokenCreds done")
 }
 
+func getToken() *oauth2.Token {
+	jsonKey := getServiceAccountJSONKey()
+	config, err := google.JWTConfigFromJSON(jsonKey, *oauthScope)
+	if err != nil {
+		grpclog.Fatalf("Failed to get the config: %v", err)
+	}
+	token, err := config.TokenSource(context.Background()).Token()
+	if err != nil {
+		grpclog.Fatalf("Failed to get the token: %v", err)
+	}
+	return token
+}
+
+func doOauth2TokenCreds(tc testpb.TestServiceClient) {
+	pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
+	req := &testpb.SimpleRequest{
+		ResponseType:   testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseSize:   proto.Int32(int32(largeRespSize)),
+		Payload:        pl,
+		FillUsername:   proto.Bool(true),
+		FillOauthScope: proto.Bool(true),
+	}
+	reply, err := tc.UnaryCall(context.Background(), req)
+	if err != nil {
+		grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err)
+	}
+	jsonKey := getServiceAccountJSONKey()
+	user := reply.GetUsername()
+	scope := reply.GetOauthScope()
+	if !strings.Contains(string(jsonKey), user) {
+		grpclog.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
+	}
+	if !strings.Contains(*oauthScope, scope) {
+		grpclog.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, *oauthScope)
+	}
+	grpclog.Println("Oauth2TokenCreds done")
+}
+
+func doPerRPCCreds(tc testpb.TestServiceClient) {
+	jsonKey := getServiceAccountJSONKey()
+	pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize)
+	req := &testpb.SimpleRequest{
+		ResponseType:   testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseSize:   proto.Int32(int32(largeRespSize)),
+		Payload:        pl,
+		FillUsername:   proto.Bool(true),
+		FillOauthScope: proto.Bool(true),
+	}
+	token := getToken()
+	kv := map[string]string{"authorization": token.TokenType + " " + token.AccessToken}
+	ctx := metadata.NewContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
+	reply, err := tc.UnaryCall(ctx, req)
+	if err != nil {
+		grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err)
+	}
+	user := reply.GetUsername()
+	scope := reply.GetOauthScope()
+	if !strings.Contains(string(jsonKey), user) {
+		grpclog.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey)
+	}
+	if !strings.Contains(*oauthScope, scope) {
+		grpclog.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, *oauthScope)
+	}
+	grpclog.Println("PerRPCCreds done")
+}
+
 var (
 	testMetadata = metadata.MD{
 		"key1": []string{"value1"},
@@ -449,6 +519,9 @@ func main() {
 				grpclog.Fatalf("Failed to create JWT credentials: %v", err)
 			}
 			opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds))
+		} else if *testCase == "oauth2_token_creds" {
+			token := getToken()
+			opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(token)))
 		}
 	}
 	conn, err := grpc.Dial(serverAddr, opts...)
@@ -487,6 +560,17 @@ func main() {
 			grpclog.Fatalf("TLS is not enabled. TLS is required to execute jwt_token_creds test case.")
 		}
 		doJWTTokenCreds(tc)
+	case "per_rpc_creds":
+		if !*useTLS {
+			grpclog.Fatalf("TLS is not enabled. TLS is required to execute per_rpc_creds test case.")
+		}
+		doPerRPCCreds(tc)
+	case "oauth2_token_creds":
+		if !*useTLS {
+			grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_token_creds test case.")
+		}
+		doOauth2TokenCreds(tc)
+
 	case "cancel_after_begin":
 		doCancelAfterBegin(tc)
 	case "cancel_after_first_response":