Add AuthInfoFromContext utility API (#2062)
This commit is contained in:
@ -30,6 +30,9 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -115,3 +118,19 @@ func readManufacturer() ([]byte, error) {
|
||||
}
|
||||
return manufacturer, nil
|
||||
}
|
||||
|
||||
// AuthInfoFromContext extracts the alts.AuthInfo object from the given context,
|
||||
// if it exists. This API should be used by gRPC server RPC handlers to get
|
||||
// information about the communicating peer. For client-side, use grpc.Peer()
|
||||
// CallOption.
|
||||
func AuthInfoFromContext(ctx context.Context) (AuthInfo, error) {
|
||||
peer, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("no Peer found in Context")
|
||||
}
|
||||
altsAuthInfo, ok := peer.AuthInfo.(AuthInfo)
|
||||
if !ok {
|
||||
return nil, errors.New("no alts.AuthInfo found in Context")
|
||||
}
|
||||
return altsAuthInfo, nil
|
||||
}
|
||||
|
@ -22,6 +22,10 @@ import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
altspb "google.golang.org/grpc/credentials/alts/core/proto/grpc_gcp"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
func TestIsRunningOnGCP(t *testing.T) {
|
||||
@ -64,3 +68,44 @@ func setup(testOS string, testReader io.Reader) func() {
|
||||
manufacturerReader = tmpReader
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthInfoFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
altsAuthInfo := &fakeALTSAuthInfo{}
|
||||
p := &peer.Peer{
|
||||
AuthInfo: altsAuthInfo,
|
||||
}
|
||||
for _, tc := range []struct {
|
||||
desc string
|
||||
ctx context.Context
|
||||
success bool
|
||||
out AuthInfo
|
||||
}{
|
||||
{
|
||||
"working case",
|
||||
peer.NewContext(ctx, p),
|
||||
true,
|
||||
altsAuthInfo,
|
||||
},
|
||||
} {
|
||||
authInfo, err := AuthInfoFromContext(tc.ctx)
|
||||
if got, want := (err == nil), tc.success; got != want {
|
||||
t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
|
||||
}
|
||||
if got, want := authInfo, tc.out; got != want {
|
||||
t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeALTSAuthInfo struct{}
|
||||
|
||||
func (*fakeALTSAuthInfo) AuthType() string { return "" }
|
||||
func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
|
||||
func (*fakeALTSAuthInfo) RecordProtocol() string { return "" }
|
||||
func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
|
||||
return altspb.SecurityLevel_SECURITY_NONE
|
||||
}
|
||||
func (*fakeALTSAuthInfo) PeerServiceAccount() string { return "" }
|
||||
func (*fakeALTSAuthInfo) LocalServiceAccount() string { return "" }
|
||||
func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }
|
||||
|
Reference in New Issue
Block a user