Add AuthInfoFromContext utility API (#2062)
This commit is contained in:
@ -30,6 +30,9 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -115,3 +118,19 @@ func readManufacturer() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return manufacturer, nil
|
return manufacturer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthInfoFromContext extracts the alts.AuthInfo object from the given context,
|
||||||
|
// if it exists. This API should be used by gRPC server RPC handlers to get
|
||||||
|
// information about the communicating peer. For client-side, use grpc.Peer()
|
||||||
|
// CallOption.
|
||||||
|
func AuthInfoFromContext(ctx context.Context) (AuthInfo, error) {
|
||||||
|
peer, ok := peer.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("no Peer found in Context")
|
||||||
|
}
|
||||||
|
altsAuthInfo, ok := peer.AuthInfo.(AuthInfo)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("no alts.AuthInfo found in Context")
|
||||||
|
}
|
||||||
|
return altsAuthInfo, nil
|
||||||
|
}
|
||||||
|
@ -22,6 +22,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
altspb "google.golang.org/grpc/credentials/alts/core/proto/grpc_gcp"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsRunningOnGCP(t *testing.T) {
|
func TestIsRunningOnGCP(t *testing.T) {
|
||||||
@ -64,3 +68,44 @@ func setup(testOS string, testReader io.Reader) func() {
|
|||||||
manufacturerReader = tmpReader
|
manufacturerReader = tmpReader
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthInfoFromContext(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
altsAuthInfo := &fakeALTSAuthInfo{}
|
||||||
|
p := &peer.Peer{
|
||||||
|
AuthInfo: altsAuthInfo,
|
||||||
|
}
|
||||||
|
for _, tc := range []struct {
|
||||||
|
desc string
|
||||||
|
ctx context.Context
|
||||||
|
success bool
|
||||||
|
out AuthInfo
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"working case",
|
||||||
|
peer.NewContext(ctx, p),
|
||||||
|
true,
|
||||||
|
altsAuthInfo,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
authInfo, err := AuthInfoFromContext(tc.ctx)
|
||||||
|
if got, want := (err == nil), tc.success; got != want {
|
||||||
|
t.Errorf("%v: AuthInfoFromContext(_)=(err=nil)=%v, want %v", tc.desc, got, want)
|
||||||
|
}
|
||||||
|
if got, want := authInfo, tc.out; got != want {
|
||||||
|
t.Errorf("%v:, AuthInfoFromContext(_)=(%v, _), want (%v, _)", tc.desc, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeALTSAuthInfo struct{}
|
||||||
|
|
||||||
|
func (*fakeALTSAuthInfo) AuthType() string { return "" }
|
||||||
|
func (*fakeALTSAuthInfo) ApplicationProtocol() string { return "" }
|
||||||
|
func (*fakeALTSAuthInfo) RecordProtocol() string { return "" }
|
||||||
|
func (*fakeALTSAuthInfo) SecurityLevel() altspb.SecurityLevel {
|
||||||
|
return altspb.SecurityLevel_SECURITY_NONE
|
||||||
|
}
|
||||||
|
func (*fakeALTSAuthInfo) PeerServiceAccount() string { return "" }
|
||||||
|
func (*fakeALTSAuthInfo) LocalServiceAccount() string { return "" }
|
||||||
|
func (*fakeALTSAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { return nil }
|
||||||
|
Reference in New Issue
Block a user