From 419de394cf7ab1b907a3f378583ca3e58af06b83 Mon Sep 17 00:00:00 2001 From: Cesar Ghali Date: Fri, 11 May 2018 14:16:43 -0700 Subject: [PATCH] Add AuthInfoFromContext utility API (#2062) --- credentials/alts/utils.go | 19 ++++++++++++++ credentials/alts/utils_test.go | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go index 86d0a213..9f92eb1d 100644 --- a/credentials/alts/utils.go +++ b/credentials/alts/utils.go @@ -30,6 +30,9 @@ import ( "runtime" "strings" "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/peer" ) const ( @@ -115,3 +118,19 @@ func readManufacturer() ([]byte, error) { } return manufacturer, nil } + +// AuthInfoFromContext extracts the alts.AuthInfo object from the given context, +// if it exists. This API should be used by gRPC server RPC handlers to get +// information about the communicating peer. For client-side, use grpc.Peer() +// CallOption. +func AuthInfoFromContext(ctx context.Context) (AuthInfo, error) { + peer, ok := peer.FromContext(ctx) + if !ok { + return nil, errors.New("no Peer found in Context") + } + altsAuthInfo, ok := peer.AuthInfo.(AuthInfo) + if !ok { + return nil, errors.New("no alts.AuthInfo found in Context") + } + return altsAuthInfo, nil +} diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go index 32c5e1bf..4724103f 100644 --- a/credentials/alts/utils_test.go +++ b/credentials/alts/utils_test.go @@ -22,6 +22,10 @@ import ( "io" "strings" "testing" + + "golang.org/x/net/context" + altspb "google.golang.org/grpc/credentials/alts/core/proto/grpc_gcp" + "google.golang.org/grpc/peer" ) func TestIsRunningOnGCP(t *testing.T) { @@ -64,3 +68,44 @@ func setup(testOS string, testReader io.Reader) func() { manufacturerReader = tmpReader } } + +func TestAuthInfoFromContext(t *testing.T) { + ctx := context.Background() + altsAuthInfo := &fakeALTSAuthInfo{} + p := &peer.Peer{ + AuthInfo: altsAuthInfo, + } + for _, tc := range []struct { + desc string + ctx context.Context + success bool + out AuthInfo + }{ + { + "working case", + peer.NewContext(ctx, p), + true, + altsAuthInfo, + }, + } { + authInfo, err := AuthInfoFromContext(tc.ctx) + if got, want := (err == nil), tc.success; got != want { + t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want) + } + if got, want := authInfo, tc.out; got != want { + t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want) + } + } +} + +type fakeALTSAuthInfo struct{} + +func (*fakeALTSAuthInfo) AuthType() string { return "" } +func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" } +func (*fakeALTSAuthInfo) RecordProtocol() string { return "" } +func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel { + return altspb.SecurityLevel_SECURITY_NONE +} +func (*fakeALTSAuthInfo) PeerServiceAccount() string { return "" } +func (*fakeALTSAuthInfo) LocalServiceAccount() string { return "" } +func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }