diff --git a/credentials/alts/alts.go b/credentials/alts/alts.go new file mode 100644 index 00000000..a5881656 --- /dev/null +++ b/credentials/alts/alts.go @@ -0,0 +1,286 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package alts implements the ALTS credential support by gRPC library, which +// encapsulates all the state needed by a client to authenticate with a server +// using ALTS and make various assertions, e.g., about the client's identity, +// role, or whether it is authorized to make a particular call. +// This package is experimental. +package alts + +import ( + "errors" + "flag" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/alts/core" + "google.golang.org/grpc/credentials/alts/core/handshaker" + "google.golang.org/grpc/credentials/alts/core/handshaker/service" + altspb "google.golang.org/grpc/credentials/alts/core/proto" + "google.golang.org/grpc/grpclog" +) + +const ( + // defaultTimeout specifies the server handshake timeout. + defaultTimeout = 30.0 * time.Second + // The following constants specify the minimum and maximum acceptable + // protocol versions. + protocolVersionMaxMajor = 2 + protocolVersionMaxMinor = 1 + protocolVersionMinMajor = 2 + protocolVersionMinMinor = 1 +) + +var ( + enableUntrustedALTS = flag.Bool("enable_untrusted_alts", false, "Enables ALTS in untrusted mode. Enabling this mode is risky since we cannot ensure that the application is running on GCP with a trusted handshaker service.") + once sync.Once + maxRPCVersion = &altspb.RpcProtocolVersions_Version{ + Major: protocolVersionMaxMajor, + Minor: protocolVersionMaxMinor, + } + minRPCVersion = &altspb.RpcProtocolVersions_Version{ + Major: protocolVersionMinMajor, + Minor: protocolVersionMinMinor, + } + // ErrUntrustedPlatform is returned from ClientHandshake and + // ServerHandshake is running on a platform where the trustworthiness of + // the handshaker service is not guaranteed. + ErrUntrustedPlatform = errors.New("untrusted platform, use enable_untrusted_alts flag at your own risk") +) + +// AuthInfo exposes security information from the ALTS handshake to the +// application. This interface is to be implemented by ALTS. Users should not +// need a brand new implementation of this interface. For situations like +// testing, any new implementation should embed this interface. This allows +// ALTS to add new methods to this interface. +type AuthInfo interface { + // ApplicationProtocol returns application protocol negotiated for the + // ALTS connection. + ApplicationProtocol() string + // RecordProtocol returns the record protocol negotiated for the ALTS + // connection. + RecordProtocol() string + // SecurityLevel returns the security level of the created ALTS secure + // channel. + SecurityLevel() altspb.SecurityLevel + // PeerServiceAccount returns the peer service account. + PeerServiceAccount() string + // LocalServiceAccount returns the local service account. + LocalServiceAccount() string + // PeerRPCVersions returns the RPC version supported by the peer. + PeerRPCVersions() *altspb.RpcProtocolVersions +} + +// altsTC is the credentials required for authenticating a connection using ALTS. +// It implements credentials.TransportCredentials interface. +type altsTC struct { + info *credentials.ProtocolInfo + hsAddr string + side core.Side + accounts []string +} + +// NewClientALTS constructs a client-side ALTS TransportCredentials object. +func NewClientALTS(targetServiceAccounts []string) credentials.TransportCredentials { + return newALTS(core.ClientSide, targetServiceAccounts) +} + +// NewServerALTS constructs a server-side ALTS TransportCredentials object. +func NewServerALTS() credentials.TransportCredentials { + return newALTS(core.ServerSide, nil) +} + +func newALTS(side core.Side, accounts []string) credentials.TransportCredentials { + // Make sure flags are parsed before accessing enableUntrustedALTS. + once.Do(func() { + flag.Parse() + vmOnGCP = isRunningOnGCP() + }) + if *enableUntrustedALTS { + grpclog.Warning("untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS handshaker service.") + } + + return &altsTC{ + info: &credentials.ProtocolInfo{ + SecurityProtocol: "alts", + SecurityVersion: "1.0", + }, + side: side, + accounts: accounts, + } +} + +// ClientHandshake implements the client side handshake protocol. +func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { + if !*enableUntrustedALTS && !vmOnGCP { + return nil, nil, ErrUntrustedPlatform + } + + // Connecting to ALTS handshaker service. + hsConn, err := service.Dial() + if err != nil { + return nil, nil, err + } + // Do not close hsConn since it is shared with other handshakes. + + // Possible context leak: + // The cancel function for the child context we create will only be + // called a non-nil error is returned. + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer func() { + if err != nil { + cancel() + } + }() + + opts := handshaker.DefaultClientHandshakerOptions() + opts.TargetServiceAccounts = g.accounts + opts.RPCVersions = &altspb.RpcProtocolVersions{ + MaxRpcVersion: maxRPCVersion, + MinRpcVersion: minRPCVersion, + } + chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts) + defer func() { + if err != nil { + chs.Close() + } + }() + if err != nil { + return nil, nil, err + } + secConn, authInfo, err := chs.ClientHandshake(ctx) + if err != nil { + return nil, nil, err + } + altsAuthInfo, ok := authInfo.(AuthInfo) + if !ok { + return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo") + } + match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) + if !match { + return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) + } + return secConn, authInfo, nil +} + +// ServerHandshake implements the server side ALTS handshaker. +func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { + if !*enableUntrustedALTS && !vmOnGCP { + return nil, nil, ErrUntrustedPlatform + } + // Connecting to ALTS handshaker service. + hsConn, err := service.Dial() + if err != nil { + return nil, nil, err + } + // Do not close hsConn since it's shared with other handshakes. + + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + opts := handshaker.DefaultServerHandshakerOptions() + opts.RPCVersions = &altspb.RpcProtocolVersions{ + MaxRpcVersion: maxRPCVersion, + MinRpcVersion: minRPCVersion, + } + shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts) + defer func() { + if err != nil { + shs.Close() + } + }() + if err != nil { + return nil, nil, err + } + secConn, authInfo, err := shs.ServerHandshake(ctx) + if err != nil { + return nil, nil, err + } + altsAuthInfo, ok := authInfo.(AuthInfo) + if !ok { + return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo") + } + match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) + if !match { + return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions()) + } + return secConn, authInfo, nil +} + +func (g *altsTC) Info() credentials.ProtocolInfo { + return *g.info +} + +func (g *altsTC) Clone() credentials.TransportCredentials { + info := *g.info + return &altsTC{ + info: &info, + } +} + +func (g *altsTC) OverrideServerName(serverNameOverride string) error { + g.info.ServerName = serverNameOverride + return nil +} + +// compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2. +func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int { + switch { + case v1.GetMajor() > v2.GetMajor(), + v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor(): + return 1 + case v1.GetMajor() < v2.GetMajor(), + v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor(): + return -1 + } + return 0 +} + +// checkRPCVersions performs a version check between local and peer rpc protocol +// versions. This function returns true if the check passes which means both +// parties agreed on a common rpc protocol to use, and false otherwise. The +// function also returns the highest common RPC protocol version both parties +// agreed on. +func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) { + if local == nil || peer == nil { + grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.") + return false, nil + } + + // maxCommonVersion is MIN(local.max, peer.max). + maxCommonVersion := local.GetMaxRpcVersion() + if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 { + maxCommonVersion = peer.GetMaxRpcVersion() + } + + // minCommonVersion is MAX(local.min, peer.min). + minCommonVersion := peer.GetMinRpcVersion() + if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 { + minCommonVersion = local.GetMinRpcVersion() + } + + if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 { + return false, nil + } + return true, maxCommonVersion +} diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go new file mode 100644 index 00000000..0b884818 --- /dev/null +++ b/credentials/alts/alts_test.go @@ -0,0 +1,246 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package alts + +import ( + "testing" + + "github.com/golang/protobuf/proto" + altspb "google.golang.org/grpc/credentials/alts/core/proto" +) + +func TestInfoServerName(t *testing.T) { + // This is not testing any handshaker functionality, so it's fine to only + // use NewServerALTS and not NewClientALTS. + alts := NewServerALTS() + if got, want := alts.Info().ServerName, ""; got != want { + t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want) + } +} + +func TestOverrideServerName(t *testing.T) { + wantServerName := "server.name" + // This is not testing any handshaker functionality, so it's fine to only + // use NewServerALTS and not NewClientALTS. + c := NewServerALTS() + c.OverrideServerName(wantServerName) + if got, want := c.Info().ServerName, wantServerName; got != want { + t.Fatalf("c.Info().ServerName = %v, want %v", got, want) + } +} + +func TestClone(t *testing.T) { + wantServerName := "server.name" + // This is not testing any handshaker functionality, so it's fine to only + // use NewServerALTS and not NewClientALTS. + c := NewServerALTS() + c.OverrideServerName(wantServerName) + cc := c.Clone() + if got, want := cc.Info().ServerName, wantServerName; got != want { + t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) + } + cc.OverrideServerName("") + if got, want := c.Info().ServerName, wantServerName; got != want { + t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want) + } + if got, want := cc.Info().ServerName, ""; got != want { + t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) + } +} + +func TestInfo(t *testing.T) { + // This is not testing any handshaker functionality, so it's fine to only + // use NewServerALTS and not NewClientALTS. + c := NewServerALTS() + info := c.Info() + if got, want := info.ProtocolVersion, ""; got != want { + t.Errorf("info.ProtocolVersion=%v, want %v", got, want) + } + if got, want := info.SecurityProtocol, "alts"; got != want { + t.Errorf("info.SecurityProtocol=%v, want %v", got, want) + } + if got, want := info.SecurityVersion, "1.0"; got != want { + t.Errorf("info.SecurityVersion=%v, want %v", got, want) + } + if got, want := info.ServerName, ""; got != want { + t.Errorf("info.ServerName=%v, want %v", got, want) + } +} + +func TestCompareRPCVersions(t *testing.T) { + for _, tc := range []struct { + v1 *altspb.RpcProtocolVersions_Version + v2 *altspb.RpcProtocolVersions_Version + output int + }{ + { + version(3, 2), + version(2, 1), + 1, + }, + { + version(3, 2), + version(3, 1), + 1, + }, + { + version(2, 1), + version(3, 2), + -1, + }, + { + version(3, 1), + version(3, 2), + -1, + }, + { + version(3, 2), + version(3, 2), + 0, + }, + } { + if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want { + t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want) + } + } +} + +func TestCheckRPCVersions(t *testing.T) { + for _, tc := range []struct { + desc string + local *altspb.RpcProtocolVersions + peer *altspb.RpcProtocolVersions + output bool + maxCommonVersion *altspb.RpcProtocolVersions_Version + }{ + { + "local.max > peer.max and local.min > peer.min", + versions(2, 1, 3, 2), + versions(1, 2, 2, 1), + true, + version(2, 1), + }, + { + "local.max > peer.max and local.min < peer.min", + versions(1, 2, 3, 2), + versions(2, 1, 2, 1), + true, + version(2, 1), + }, + { + "local.max > peer.max and local.min = peer.min", + versions(2, 1, 3, 2), + versions(2, 1, 2, 1), + true, + version(2, 1), + }, + { + "local.max < peer.max and local.min > peer.min", + versions(2, 1, 2, 1), + versions(1, 2, 3, 2), + true, + version(2, 1), + }, + { + "local.max = peer.max and local.min > peer.min", + versions(2, 1, 2, 1), + versions(1, 2, 2, 1), + true, + version(2, 1), + }, + { + "local.max < peer.max and local.min < peer.min", + versions(1, 2, 2, 1), + versions(2, 1, 3, 2), + true, + version(2, 1), + }, + { + "local.max < peer.max and local.min = peer.min", + versions(1, 2, 2, 1), + versions(1, 2, 3, 2), + true, + version(2, 1), + }, + { + "local.max = peer.max and local.min < peer.min", + versions(1, 2, 2, 1), + versions(2, 1, 2, 1), + true, + version(2, 1), + }, + { + "all equal", + versions(2, 1, 2, 1), + versions(2, 1, 2, 1), + true, + version(2, 1), + }, + { + "max is smaller than min", + versions(2, 1, 1, 2), + versions(2, 1, 1, 2), + false, + nil, + }, + { + "no overlap, local > peer", + versions(4, 3, 6, 5), + versions(1, 0, 2, 1), + false, + nil, + }, + { + "no overlap, local < peer", + versions(1, 0, 2, 1), + versions(4, 3, 6, 5), + false, + nil, + }, + { + "no overlap, max < min", + versions(6, 5, 4, 3), + versions(2, 1, 1, 0), + false, + nil, + }, + } { + output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer) + if got, want := output, tc.output; got != want { + t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want) + } + if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) { + t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want) + } + } +} + +func version(major, minor uint32) *altspb.RpcProtocolVersions_Version { + return &altspb.RpcProtocolVersions_Version{ + Major: major, + Minor: minor, + } +} + +func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions { + return &altspb.RpcProtocolVersions{ + MinRpcVersion: version(minMajor, minMinor), + MaxRpcVersion: version(maxMajor, maxMinor), + } +} diff --git a/credentials/alts/core/authinfo/authinfo.go b/credentials/alts/core/authinfo/authinfo.go new file mode 100644 index 00000000..284edbff --- /dev/null +++ b/credentials/alts/core/authinfo/authinfo.go @@ -0,0 +1,87 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package authinfo provide authentication information returned by handshakers. +package authinfo + +import ( + "google.golang.org/grpc/credentials" + altspb "google.golang.org/grpc/credentials/alts/core/proto" +) + +var _ credentials.AuthInfo = (*altsAuthInfo)(nil) + +// altsAuthInfo exposes security information from the ALTS handshake to the +// application. altsAuthInfo is immutable and implements credentials.AuthInfo. +type altsAuthInfo struct { + p *altspb.AltsContext +} + +// New returns a new altsAuthInfo object given handshaker results. +func New(result *altspb.HandshakerResult) credentials.AuthInfo { + return newAuthInfo(result) +} + +func newAuthInfo(result *altspb.HandshakerResult) *altsAuthInfo { + return &altsAuthInfo{ + p: &altspb.AltsContext{ + ApplicationProtocol: result.GetApplicationProtocol(), + RecordProtocol: result.GetRecordProtocol(), + // TODO: assign security level from result. + SecurityLevel: altspb.SecurityLevel_INTEGRITY_AND_PRIVACY, + PeerServiceAccount: result.GetPeerIdentity().GetServiceAccount(), + LocalServiceAccount: result.GetLocalIdentity().GetServiceAccount(), + PeerRpcVersions: result.GetPeerRpcVersions(), + }, + } +} + +// AuthType identifies the context as providing ALTS authentication information. +func (s *altsAuthInfo) AuthType() string { + return "alts" +} + +// ApplicationProtocol returns the context's application protocol. +func (s *altsAuthInfo) ApplicationProtocol() string { + return s.p.GetApplicationProtocol() +} + +// RecordProtocol returns the context's record protocol. +func (s *altsAuthInfo) RecordProtocol() string { + return s.p.GetRecordProtocol() +} + +// SecurityLevel returns the context's security level. +func (s *altsAuthInfo) SecurityLevel() altspb.SecurityLevel { + return s.p.GetSecurityLevel() +} + +// PeerServiceAccount returns the context's peer service account. +func (s *altsAuthInfo) PeerServiceAccount() string { + return s.p.GetPeerServiceAccount() +} + +// LocalServiceAccount returns the context's local service account. +func (s *altsAuthInfo) LocalServiceAccount() string { + return s.p.GetLocalServiceAccount() +} + +// PeerRPCVersions returns the context's peer RPC versions. +func (s *altsAuthInfo) PeerRPCVersions() *altspb.RpcProtocolVersions { + return s.p.GetPeerRpcVersions() +} diff --git a/credentials/alts/core/authinfo/authinfo_test.go b/credentials/alts/core/authinfo/authinfo_test.go new file mode 100644 index 00000000..b3d430a3 --- /dev/null +++ b/credentials/alts/core/authinfo/authinfo_test.go @@ -0,0 +1,134 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package authinfo + +import ( + "reflect" + "testing" + + altspb "google.golang.org/grpc/credentials/alts/core/proto" +) + +const ( + testAppProtocol = "my_app" + testRecordProtocol = "very_secure_protocol" + testPeerAccount = "peer_service_account" + testLocalAccount = "local_service_account" + testPeerHostname = "peer_hostname" + testLocalHostname = "local_hostname" +) + +func TestALTSAuthInfo(t *testing.T) { + for _, tc := range []struct { + result *altspb.HandshakerResult + outAppProtocol string + outRecordProtocol string + outSecurityLevel altspb.SecurityLevel + outPeerAccount string + outLocalAccount string + outPeerRPCVersions *altspb.RpcProtocolVersions + }{ + { + &altspb.HandshakerResult{ + ApplicationProtocol: testAppProtocol, + RecordProtocol: testRecordProtocol, + PeerIdentity: &altspb.Identity{ + IdentityOneof: &altspb.Identity_ServiceAccount{ + ServiceAccount: testPeerAccount, + }, + }, + LocalIdentity: &altspb.Identity{ + IdentityOneof: &altspb.Identity_ServiceAccount{ + ServiceAccount: testLocalAccount, + }, + }, + }, + testAppProtocol, + testRecordProtocol, + altspb.SecurityLevel_INTEGRITY_AND_PRIVACY, + testPeerAccount, + testLocalAccount, + nil, + }, + { + &altspb.HandshakerResult{ + ApplicationProtocol: testAppProtocol, + RecordProtocol: testRecordProtocol, + PeerIdentity: &altspb.Identity{ + IdentityOneof: &altspb.Identity_Hostname{ + Hostname: testPeerHostname, + }, + }, + LocalIdentity: &altspb.Identity{ + IdentityOneof: &altspb.Identity_Hostname{ + Hostname: testLocalHostname, + }, + }, + PeerRpcVersions: &altspb.RpcProtocolVersions{ + MaxRpcVersion: &altspb.RpcProtocolVersions_Version{ + Major: 20, + Minor: 21, + }, + MinRpcVersion: &altspb.RpcProtocolVersions_Version{ + Major: 10, + Minor: 11, + }, + }, + }, + testAppProtocol, + testRecordProtocol, + altspb.SecurityLevel_INTEGRITY_AND_PRIVACY, + "", + "", + &altspb.RpcProtocolVersions{ + MaxRpcVersion: &altspb.RpcProtocolVersions_Version{ + Major: 20, + Minor: 21, + }, + MinRpcVersion: &altspb.RpcProtocolVersions_Version{ + Major: 10, + Minor: 11, + }, + }, + }, + } { + authInfo := newAuthInfo(tc.result) + if got, want := authInfo.AuthType(), "alts"; got != want { + t.Errorf("authInfo.AuthType()=%v, want %v", got, want) + } + if got, want := authInfo.ApplicationProtocol(), tc.outAppProtocol; got != want { + t.Errorf("authInfo.ApplicationProtocol()=%v, want %v", got, want) + } + if got, want := authInfo.RecordProtocol(), tc.outRecordProtocol; got != want { + t.Errorf("authInfo.RecordProtocol()=%v, want %v", got, want) + } + if got, want := authInfo.SecurityLevel(), tc.outSecurityLevel; got != want { + t.Errorf("authInfo.SecurityLevel()=%v, want %v", got, want) + } + if got, want := authInfo.PeerServiceAccount(), tc.outPeerAccount; got != want { + t.Errorf("authInfo.PeerServiceAccount()=%v, want %v", got, want) + } + if got, want := authInfo.LocalServiceAccount(), tc.outLocalAccount; got != want { + t.Errorf("authInfo.LocalServiceAccount()=%v, want %v", got, want) + } + if got, want := authInfo.PeerRPCVersions(), tc.outPeerRPCVersions; !reflect.DeepEqual(got, want) { + t.Errorf("authinfo.PeerRpcVersions()=%v, want %v", got, want) + } + } +} diff --git a/credentials/alts/core/common.go b/credentials/alts/core/common.go new file mode 100644 index 00000000..0112d51c --- /dev/null +++ b/credentials/alts/core/common.go @@ -0,0 +1,68 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package core contains common core functionality for ALTS. +// Disclaimer: users should NEVER reference this package directly. +package core + +import ( + "net" + + "golang.org/x/net/context" + "google.golang.org/grpc/credentials" +) + +const ( + // ClientSide identifies the client in this communication. + ClientSide Side = iota + // ServerSide identifies the server in this communication. + ServerSide +) + +// PeerNotRespondingError is returned when a peer server is not responding +// after a channel has been established. It is treated as a temporary connection +// error and re-connection to the server should be attempted. +var PeerNotRespondingError = &peerNotRespondingError{} + +// Side identifies the party's role: client or server. +type Side int + +type peerNotRespondingError struct{} + +// Return an error message for the purpose of logging. +func (e *peerNotRespondingError) Error() string { + return "peer server is not responding and re-connection should be attempted." +} + +// Temporary indicates if this connection error is temporary or fatal. +func (e *peerNotRespondingError) Temporary() bool { + return true +} + +// Handshaker defines a ALTS handshaker interface. +type Handshaker interface { + // ClientHandshake starts and completes a client-side handshaking and + // returns a secure connection and corresponding auth information. + ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) + // ServerHandshake starts and completes a server-side handshaking and + // returns a secure connection and corresponding auth information. + ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) + // Close terminates the Handshaker. It should be called when the caller + // obtains the secure connection. + Close() +} diff --git a/credentials/alts/core/conn/aeadrekey.go b/credentials/alts/core/conn/aeadrekey.go new file mode 100644 index 00000000..43726e87 --- /dev/null +++ b/credentials/alts/core/conn/aeadrekey.go @@ -0,0 +1,131 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "fmt" + "strconv" +) + +// rekeyAEAD holds the necessary information for an AEAD based on +// AES-GCM that performs nonce-based key derivation and XORs the +// nonce with a random mask. +type rekeyAEAD struct { + kdfKey []byte + kdfCounter []byte + nonceMask []byte + nonceBuf []byte + gcmAEAD cipher.AEAD +} + +// KeySizeError signals that the given key does not have the correct size. +type KeySizeError int + +func (k KeySizeError) Error() string { + return "alts/conn: invalid key size " + strconv.Itoa(int(k)) +} + +// newRekeyAEAD creates a new instance of aes128gcm with rekeying. +// The key argument should be 44 bytes, the first 32 bytes are used as a key +// for HKDF-expand and the remainining 12 bytes are used as a random mask for +// the counter. +func newRekeyAEAD(key []byte) (*rekeyAEAD, error) { + k := len(key) + if k != kdfKeyLen+nonceLen { + return nil, KeySizeError(k) + } + return &rekeyAEAD{ + kdfKey: key[:kdfKeyLen], + kdfCounter: make([]byte, kdfCounterLen), + nonceMask: key[kdfKeyLen:], + nonceBuf: make([]byte, nonceLen), + gcmAEAD: nil, + }, nil +} + +// Seal rekeys if nonce[2:8] is different than in the last call, masks the nonce, +// and calls Seal for aes128gcm. +func (s *rekeyAEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte { + if err := s.rekeyIfRequired(nonce); err != nil { + panic(fmt.Sprintf("Rekeying failed with: %s", err.Error())) + } + maskNonce(s.nonceBuf, nonce, s.nonceMask) + return s.gcmAEAD.Seal(dst, s.nonceBuf, plaintext, additionalData) +} + +// Open rekeys if nonce[2:8] is different than in the last call, masks the nonce, +// and calls Open for aes128gcm. +func (s *rekeyAEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + if err := s.rekeyIfRequired(nonce); err != nil { + return nil, err + } + maskNonce(s.nonceBuf, nonce, s.nonceMask) + return s.gcmAEAD.Open(dst, s.nonceBuf, ciphertext, additionalData) +} + +// rekeyIfRequired creates a new aes128gcm AEAD if the existing AEAD is nil +// or cannot be used with given nonce. +func (s *rekeyAEAD) rekeyIfRequired(nonce []byte) error { + newKdfCounter := nonce[kdfCounterOffset : kdfCounterOffset+kdfCounterLen] + if s.gcmAEAD != nil && bytes.Equal(newKdfCounter, s.kdfCounter) { + return nil + } + copy(s.kdfCounter, newKdfCounter) + a, err := aes.NewCipher(hkdfExpand(s.kdfKey, s.kdfCounter)) + if err != nil { + return err + } + s.gcmAEAD, err = cipher.NewGCM(a) + return err +} + +// maskNonce XORs the given nonce with the mask and stores the result in dst. +func maskNonce(dst, nonce, mask []byte) { + nonce1 := binary.LittleEndian.Uint64(nonce[:sizeUint64]) + nonce2 := binary.LittleEndian.Uint32(nonce[sizeUint64:]) + mask1 := binary.LittleEndian.Uint64(mask[:sizeUint64]) + mask2 := binary.LittleEndian.Uint32(mask[sizeUint64:]) + binary.LittleEndian.PutUint64(dst[:sizeUint64], nonce1^mask1) + binary.LittleEndian.PutUint32(dst[sizeUint64:], nonce2^mask2) +} + +// NonceSize returns the required nonce size. +func (s *rekeyAEAD) NonceSize() int { + return s.gcmAEAD.NonceSize() +} + +// Overhead returns the ciphertext overhead. +func (s *rekeyAEAD) Overhead() int { + return s.gcmAEAD.Overhead() +} + +// hkdfExpand computes the first 16 bytes of the HKDF-expand function +// defined in RFC5869. +func hkdfExpand(key, info []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(info) + mac.Write([]byte{0x01}[:]) + return mac.Sum(nil)[:aeadKeyLen] +} diff --git a/credentials/alts/core/conn/aeadrekey_test.go b/credentials/alts/core/conn/aeadrekey_test.go new file mode 100644 index 00000000..d639b38a --- /dev/null +++ b/credentials/alts/core/conn/aeadrekey_test.go @@ -0,0 +1,263 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "bytes" + "encoding/hex" + "testing" +) + +// cryptoTestVector is struct for a rekey test vector +type rekeyAEADTestVector struct { + desc string + key, nonce, plaintext, aad, ciphertext []byte +} + +// Test encrypt and decrypt using (adapted) test vectors for AES-GCM. +func TestAES128GCMRekeyEncrypt(t *testing.T) { + for _, test := range []rekeyAEADTestVector{ + // NIST vectors from: + // http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf + // + // IEEE vectors from: + // http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf + // + // Key expanded by setting + // expandedKey = (key || + // key ^ {0x01,..,0x01} || + // key ^ {0x02,..,0x02})[0:44]. + { + desc: "Derived from NIST test vector 1", + key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"), + nonce: dehex("000000000000000000000000"), + aad: dehex(""), + plaintext: dehex(""), + ciphertext: dehex("85e873e002f6ebdc4060954eb8675508"), + }, + { + desc: "Derived from NIST test vector 2", + key: dehex("0000000000000000000000000000000001010101010101010101010101010101020202020202020202020202"), + nonce: dehex("000000000000000000000000"), + aad: dehex(""), + plaintext: dehex("00000000000000000000000000000000"), + ciphertext: dehex("51e9a8cb23ca2512c8256afff8e72d681aca19a1148ac115e83df4888cc00d11"), + }, + { + desc: "Derived from NIST test vector 3", + key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"), + nonce: dehex("cafebabefacedbaddecaf888"), + aad: dehex(""), + plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255"), + ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c0df1162129952213cee1bc6e9c8495dd705e1f3d"), + }, + { + desc: "Derived from NIST test vector 4", + key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"), + nonce: dehex("cafebabefacedbaddecaf888"), + aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"), + plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"), + ciphertext: dehex("1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4ac8cf09afb1663daa7b4017e6fc2c177c0c087c4764565d077e9124001ddb27fc0848c5"), + }, + { + desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 15)", + key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"), + nonce: dehex("ca7ebabefacedbaddecaf888"), + aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"), + plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"), + ciphertext: dehex("e650d3c0fb879327f2d03287fa93cd07342b136215adbca00c3bd5099ec41832b1d18e0423ed26bb12c6cd09debb29230a94c0cee15903656f85edb6fc509b1b28216382172ecbcc31e1e9b1"), + }, + { + desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 16)", + key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"), + nonce: dehex("cafebbbefacedbaddecaf888"), + aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"), + plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"), + ciphertext: dehex("c0121e6c954d0767f96630c33450999791b2da2ad05c4190169ccad9ac86ff1c721e3d82f2ad22ab463bab4a0754b7dd68ca4de7ea2531b625eda01f89312b2ab957d5c7f8568dd95fcdcd1f"), + }, + { + desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 63)", + key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"), + nonce: dehex("cafebabefacedb2ddecaf888"), + aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"), + plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"), + ciphertext: dehex("8af37ea5684a4d81d4fd817261fd9743099e7e6a025eaacf8e54b124fb5743149e05cb89f4a49467fe2e5e5965f29a19f99416b0016b54585d12553783ba59e9f782e82e097c336bf7989f08"), + }, + { + desc: "Derived from adapted NIST test vector 4 for KDF counter boundary (flip nonce bit 64)", + key: dehex("feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f688d96"), + nonce: dehex("cafebabefacedbaddfcaf888"), + aad: dehex("feedfacedeadbeeffeedfacedeadbeefabaddad2"), + plaintext: dehex("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39"), + ciphertext: dehex("fbd528448d0346bfa878634864d407a35a039de9db2f1feb8e965b3ae9356ce6289441d77f8f0df294891f37ea438b223e3bf2bdc53d4c5a74fb680bb312a8dec6f7252cbcd7f5799750ad78"), + }, + { + desc: "Derived from IEEE 2.1.1 54-byte auth", + key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"), + nonce: dehex("12153524c0895e81b2c28465"), + aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"), + plaintext: dehex(""), + ciphertext: dehex("3ea0b584f3c85e93f9320ea591699efb"), + }, + { + desc: "Derived from IEEE 2.1.2 54-byte auth", + key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"), + nonce: dehex("12153524c0895e81b2c28465"), + aad: dehex("d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001"), + plaintext: dehex(""), + ciphertext: dehex("294e028bf1fe6f14c4e8f7305c933eb5"), + }, + { + desc: "Derived from IEEE 2.2.1 60-byte crypt", + key: dehex("ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d600dde"), + nonce: dehex("12153524c0895e81b2c28465"), + aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"), + ciphertext: dehex("db3d25719c6b0a3ca6145c159d5c6ed9aff9c6e0b79f17019ea923b8665ddf52137ad611f0d1bf417a7ca85e45afe106ff9c7569d335d086ae6c03f00987ccd6"), + }, + { + desc: "Derived from IEEE 2.2.2 60-byte crypt", + key: dehex("e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97a50755"), + nonce: dehex("12153524c0895e81b2c28465"), + aad: dehex("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0002"), + ciphertext: dehex("1641f28ec13afcc8f7903389787201051644914933e9202bb9d06aa020c2a67ef51dfe7bc00a856c55b8f8133e77f659132502bad63f5713d57d0c11e0f871ed"), + }, + { + desc: "Derived from IEEE 2.3.1 60-byte auth", + key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"), + nonce: dehex("f0761e8dcd3d000176d457ed"), + aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"), + plaintext: dehex(""), + ciphertext: dehex("58837a10562b0f1f8edbe58ca55811d3"), + }, + { + desc: "Derived from IEEE 2.3.2 60-byte auth", + key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"), + nonce: dehex("f0761e8dcd3d000176d457ed"), + aad: dehex("e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a0003"), + plaintext: dehex(""), + ciphertext: dehex("c2722ff6ca29a257718a529d1f0c6a3b"), + }, + { + desc: "Derived from IEEE 2.4.1 54-byte crypt", + key: dehex("071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fccecd3f07"), + nonce: dehex("f0761e8dcd3d000176d457ed"), + aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"), + ciphertext: dehex("fd96b715b93a13346af51e8acdf792cdc7b2686f8574c70e6b0cbf16291ded427ad73fec48cd298e0528a1f4c644a949fc31dc9279706ddba33f"), + }, + { + desc: "Derived from IEEE 2.4.2 54-byte crypt", + key: dehex("691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365ff1ea2"), + nonce: dehex("f0761e8dcd3d000176d457ed"), + aad: dehex("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340004"), + ciphertext: dehex("b68f6300c2e9ae833bdc070e24021a3477118e78ccf84e11a485d861476c300f175353d5cdf92008a4f878e6cc3577768085c50a0e98fda6cbb8"), + }, + { + desc: "Derived from IEEE 2.5.1 65-byte auth", + key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"), + nonce: dehex("7cfde9f9e33724c68932d612"), + aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"), + plaintext: dehex(""), + ciphertext: dehex("cca20eecda6283f09bb3543dd99edb9b"), + }, + { + desc: "Derived from IEEE 2.5.2 65-byte auth", + key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"), + nonce: dehex("7cfde9f9e33724c68932d612"), + aad: dehex("84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f0005"), + plaintext: dehex(""), + ciphertext: dehex("b232cc1da5117bf15003734fa599d271"), + }, + { + desc: "Derived from IEEE 2.6.1 61-byte crypt", + key: dehex("013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d846f0eb9"), + nonce: dehex("7cfde9f9e33724c68932d612"), + aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"), + ciphertext: dehex("ff1910d35ad7e5657890c7c560146fd038707f204b66edbc3d161f8ace244b985921023c436e3a1c3532ecd5d09a056d70be583f0d10829d9387d07d33d872e490"), + }, + { + desc: "Derived from IEEE 2.6.2 61-byte crypt", + key: dehex("83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2d89068"), + nonce: dehex("7cfde9f9e33724c68932d612"), + aad: dehex("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b0006"), + ciphertext: dehex("0db4cf956b5f97eca4eab82a6955307f9ae02a32dd7d93f83d66ad04e1cfdc5182ad12abdea5bbb619a1bd5fb9a573590fba908e9c7a46c1f7ba0905d1b55ffda4"), + }, + { + desc: "Derived from IEEE 2.7.1 79-byte crypt", + key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"), + nonce: dehex("7ae8e2ca4ec500012e58495c"), + aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"), + plaintext: dehex(""), + ciphertext: dehex("813f0e630f96fb2d030f58d83f5cdfd0"), + }, + { + desc: "Derived from IEEE 2.7.2 79-byte crypt", + key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"), + nonce: dehex("7ae8e2ca4ec500012e58495c"), + aad: dehex("68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d0007"), + plaintext: dehex(""), + ciphertext: dehex("77e5a44c21eb07188aacbd74d1980e97"), + }, + { + desc: "Derived from IEEE 2.8.1 61-byte crypt", + key: dehex("88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f47058ab"), + nonce: dehex("7ae8e2ca4ec500012e58495c"), + aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"), + ciphertext: dehex("958ec3f6d60afeda99efd888f175e5fcd4c87b9bcc5c2f5426253a8b506296c8c43309ab2adb5939462541d95e80811e04e706b1498f2c407c7fb234f8cc01a647550ee6b557b35a7e3945381821f4"), + }, + { + desc: "Derived from IEEE 2.8.2 61-byte crypt", + key: dehex("4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476fab7ba"), + nonce: dehex("7ae8e2ca4ec500012e58495c"), + aad: dehex("68f2e77696ce7ae8e2ca4ec588e54d002e58495c"), + plaintext: dehex("08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748490008"), + ciphertext: dehex("b44d072011cd36d272a9b7a98db9aa90cbc5c67b93ddce67c854503214e2e896ec7e9db649ed4bcf6f850aac0223d0cf92c83db80795c3a17ecc1248bb00591712b1ae71e268164196252162810b00"), + }} { + aead, err := newRekeyAEAD(test.key) + if err != nil { + t.Fatal("unexpected failure in newRekeyAEAD: ", err.Error()) + } + if got := aead.Seal(nil, test.nonce, test.plaintext, test.aad); !bytes.Equal(got, test.ciphertext) { + t.Errorf("Unexpected ciphertext for test vector '%s':\nciphertext=%s\nwant= %s", + test.desc, hex.EncodeToString(got), hex.EncodeToString(test.ciphertext)) + } + if got, err := aead.Open(nil, test.nonce, test.ciphertext, test.aad); err != nil || !bytes.Equal(got, test.plaintext) { + t.Errorf("Unexpected plaintext for test vector '%s':\nplaintext=%s (err=%v)\nwant= %s", + test.desc, hex.EncodeToString(got), err, hex.EncodeToString(test.plaintext)) + } + + } +} + +func dehex(s string) []byte { + if len(s) == 0 { + return make([]byte, 0) + } + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} diff --git a/credentials/alts/core/conn/aes128gcm.go b/credentials/alts/core/conn/aes128gcm.go new file mode 100644 index 00000000..0c4fe339 --- /dev/null +++ b/credentials/alts/core/conn/aes128gcm.go @@ -0,0 +1,105 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "crypto/aes" + "crypto/cipher" + + "google.golang.org/grpc/credentials/alts/core" +) + +const ( + // Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in + // each direction). + overflowLenAES128GCM = 5 +) + +// aes128gcm is the struct that holds necessary information for ALTS record. +// The counter value is NOT included in the payload during the encryption and +// decryption operations. +type aes128gcm struct { + // inCounter is used in ALTS record to check that incoming counters are + // as expected, since ALTS record guarantees that messages are unwrapped + // in the same order that the peer wrapped them. + inCounter counter + outCounter counter + aead cipher.AEAD +} + +// NewAES128GCM creates an instance that uses aes128gcm for ALTS record. +func NewAES128GCM(side core.Side, key []byte) (ALTSRecordCrypto, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + a, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + return &aes128gcm{ + inCounter: newInCounter(side, overflowLenAES128GCM), + outCounter: newOutCounter(side, overflowLenAES128GCM), + aead: a, + }, nil +} + +// Encrypt is the encryption function. dst can contain bytes at the beginning of +// the ciphertext that will not be encrypted but will be authenticated. If dst +// has enough capacity to hold these bytes, the ciphertext and the tag, no +// allocation and copy operations will be performed. dst and plaintext do not +// overlap. +func (s *aes128gcm) Encrypt(dst, plaintext []byte) ([]byte, error) { + // If we need to allocate an output buffer, we want to include space for + // GCM tag to avoid forcing ALTS record to reallocate as well. + dlen := len(dst) + dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize) + seq, err := s.outCounter.Value() + if err != nil { + return nil, err + } + data := out[:len(plaintext)] + copy(data, plaintext) // data may alias plaintext + + // Seal appends the ciphertext and the tag to its first argument and + // returns the updated slice. However, SliceForAppend above ensures that + // dst has enough capacity to avoid a reallocation and copy due to the + // append. + dst = s.aead.Seal(dst[:dlen], seq, data, nil) + s.outCounter.Inc() + return dst, nil +} + +func (s *aes128gcm) EncryptionOverhead() int { + return GcmTagSize +} + +func (s *aes128gcm) Decrypt(dst, ciphertext []byte) ([]byte, error) { + seq, err := s.inCounter.Value() + if err != nil { + return nil, err + } + // If dst is equal to ciphertext[:0], ciphertext storage is reused. + plaintext, err := s.aead.Open(dst, seq, ciphertext, nil) + if err != nil { + return nil, ErrAuth + } + s.inCounter.Inc() + return plaintext, nil +} diff --git a/credentials/alts/core/conn/aes128gcm_test.go b/credentials/alts/core/conn/aes128gcm_test.go new file mode 100644 index 00000000..c2fca4dd --- /dev/null +++ b/credentials/alts/core/conn/aes128gcm_test.go @@ -0,0 +1,223 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "bytes" + "testing" + + "google.golang.org/grpc/credentials/alts/core" +) + +// cryptoTestVector is struct for a GCM test vector +type cryptoTestVector struct { + key, counter, plaintext, ciphertext, tag []byte + allocateDst bool +} + +// getGCMCryptoPair outputs a client/server pair on aes128gcm. +func getGCMCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) { + client, err := NewAES128GCM(core.ClientSide, key) + if err != nil { + t.Fatalf("NewAES128GCM(ClientSide, key) = %v", err) + } + server, err := NewAES128GCM(core.ServerSide, key) + if err != nil { + t.Fatalf("NewAES128GCM(ServerSide, key) = %v", err) + } + // set counter if provided. + if counter != nil { + if counterSide(counter) == core.ClientSide { + client.(*aes128gcm).outCounter = counterFromValue(counter, overflowLenAES128GCM) + server.(*aes128gcm).inCounter = counterFromValue(counter, overflowLenAES128GCM) + } else { + server.(*aes128gcm).outCounter = counterFromValue(counter, overflowLenAES128GCM) + client.(*aes128gcm).inCounter = counterFromValue(counter, overflowLenAES128GCM) + } + } + return client, server +} + +func testGCMEncryptionDecryption(sender ALTSRecordCrypto, receiver ALTSRecordCrypto, test *cryptoTestVector, withCounter bool, t *testing.T) { + // Ciphertext is: counter + encrypted text + tag. + ciphertext := []byte(nil) + if withCounter { + ciphertext = append(ciphertext, test.counter...) + } + ciphertext = append(ciphertext, test.ciphertext...) + ciphertext = append(ciphertext, test.tag...) + + // Decrypt. + if got, err := receiver.Decrypt(nil, ciphertext); err != nil || !bytes.Equal(got, test.plaintext) { + t.Errorf("key=%v\ncounter=%v\ntag=%v\nciphertext=%v\nDecrypt = %v, %v\nwant: %v", + test.key, test.counter, test.tag, test.ciphertext, got, err, test.plaintext) + } + + // Encrypt. + var dst []byte + if test.allocateDst { + dst = make([]byte, len(test.plaintext)+sender.EncryptionOverhead()) + } + if got, err := sender.Encrypt(dst[:0], test.plaintext); err != nil || !bytes.Equal(got, ciphertext) { + t.Errorf("key=%v\ncounter=%v\nplaintext=%v\nEncrypt = %v, %v\nwant: %v", + test.key, test.counter, test.plaintext, got, err, ciphertext) + } +} + +// Test encrypt and decrypt using test vectors for aes128gcm. +func TestAES128GCMEncrypt(t *testing.T) { + for _, test := range []cryptoTestVector{ + { + key: dehex("11754cd72aec309bf52f7687212e8957"), + counter: dehex("3c819d9a9bed087615030b65"), + plaintext: nil, + ciphertext: nil, + tag: dehex("250327c674aaf477aef2675748cf6971"), + allocateDst: false, + }, + { + key: dehex("ca47248ac0b6f8372a97ac43508308ed"), + counter: dehex("ffd2b598feabc9019262d2be"), + plaintext: nil, + ciphertext: nil, + tag: dehex("60d20404af527d248d893ae495707d1a"), + allocateDst: false, + }, + { + key: dehex("7fddb57453c241d03efbed3ac44e371c"), + counter: dehex("ee283a3fc75575e33efd4887"), + plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"), + ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"), + tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"), + allocateDst: false, + }, + { + key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"), + counter: dehex("54cc7dc2c37ec006bcc6d1da"), + plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"), + ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"), + tag: dehex("2b4401346697138c7a4891ee59867d0c"), + allocateDst: false, + }, + { + key: dehex("11754cd72aec309bf52f7687212e8957"), + counter: dehex("3c819d9a9bed087615030b65"), + plaintext: nil, + ciphertext: nil, + tag: dehex("250327c674aaf477aef2675748cf6971"), + allocateDst: true, + }, + { + key: dehex("ca47248ac0b6f8372a97ac43508308ed"), + counter: dehex("ffd2b598feabc9019262d2be"), + plaintext: nil, + ciphertext: nil, + tag: dehex("60d20404af527d248d893ae495707d1a"), + allocateDst: true, + }, + { + key: dehex("7fddb57453c241d03efbed3ac44e371c"), + counter: dehex("ee283a3fc75575e33efd4887"), + plaintext: dehex("d5de42b461646c255c87bd2962d3b9a2"), + ciphertext: dehex("2ccda4a5415cb91e135c2a0f78c9b2fd"), + tag: dehex("b36d1df9b9d5e596f83e8b7f52971cb3"), + allocateDst: true, + }, + { + key: dehex("ab72c77b97cb5fe9a382d9fe81ffdbed"), + counter: dehex("54cc7dc2c37ec006bcc6d1da"), + plaintext: dehex("007c5e5b3e59df24a7c355584fc1518d"), + ciphertext: dehex("0e1bde206a07a9c2c1b65300f8c64997"), + tag: dehex("2b4401346697138c7a4891ee59867d0c"), + allocateDst: true, + }, + } { + // Test encryption and decryption for aes128gcm. + client, server := getGCMCryptoPair(test.key, test.counter, t) + if counterSide(test.counter) == core.ClientSide { + testGCMEncryptionDecryption(client, server, &test, false, t) + } else { + testGCMEncryptionDecryption(server, client, &test, false, t) + } + } +} + +func testGCMEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) { + // Encrypt. + const plaintext = "This is plaintext." + var err error + buf := []byte(plaintext) + buf, err = client.Encrypt(buf[:0], buf) + if err != nil { + t.Fatal("Encrypting with client-side context: unexpected error", err, "\n", + "Plaintext:", []byte(plaintext)) + } + + // Encrypt a second message. + const plaintext2 = "This is a second plaintext." + buf2 := []byte(plaintext2) + buf2, err = client.Encrypt(buf2[:0], buf2) + if err != nil { + t.Fatal("Encrypting with client-side context: unexpected error", err, "\n", + "Plaintext:", []byte(plaintext2)) + } + + // Decryption fails: cannot decrypt second message before first. + if got, err := server.Decrypt(nil, buf2); err == nil { + t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n", + " Original plaintext:", []byte(plaintext2), "\n", + " Ciphertext:", buf2, "\n", + " Decrypted plaintext:", got) + } + + // Decryption fails: wrong counter space. + if got, err := client.Decrypt(nil, buf); err == nil { + t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n", + " Original plaintext:", []byte(plaintext), "\n", + " Ciphertext:", buf, "\n", + " Decrypted plaintext:", got) + } + + // Decrypt first message. + ciphertext := append([]byte(nil), buf...) + buf, err = server.Decrypt(buf[:0], buf) + if err != nil || string(buf) != plaintext { + t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n", + " Original plaintext:", []byte(plaintext), "\n", + " Ciphertext:", ciphertext, "\n", + " Decryption error:", err, "\n", + " Decrypted plaintext:", buf) + } + + // Decryption fails: replay attack. + if got, err := server.Decrypt(nil, buf); err == nil { + t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n", + " Original plaintext:", []byte(plaintext), "\n", + " Ciphertext:", buf, "\n", + " Decrypted plaintext:", got) + } +} + +// Test encrypt and decrypt on roundtrip messages for aes128gcm. +func TestAES128GCMEncryptRoundtrip(t *testing.T) { + // Test for aes128gcm. + key := make([]byte, 16) + client, server := getGCMCryptoPair(key, nil, t) + testGCMEncryptRoundtrip(client, server, t) +} diff --git a/credentials/alts/core/conn/aes128gcmrekey.go b/credentials/alts/core/conn/aes128gcmrekey.go new file mode 100644 index 00000000..4b5ecf40 --- /dev/null +++ b/credentials/alts/core/conn/aes128gcmrekey.go @@ -0,0 +1,116 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "crypto/cipher" + + "google.golang.org/grpc/credentials/alts/core" +) + +const ( + // Overflow length n in bytes, never encrypt more than 2^(n*8) frames (in + // each direction). + overflowLenAES128GCMRekey = 8 + nonceLen = 12 + aeadKeyLen = 16 + kdfKeyLen = 32 + kdfCounterOffset = 2 + kdfCounterLen = 6 + sizeUint64 = 8 +) + +// aes128gcmRekey is the struct that holds necessary information for ALTS record. +// The counter value is NOT included in the payload during the encryption and +// decryption operations. +type aes128gcmRekey struct { + // inCounter is used in ALTS record to check that incoming counters are + // as expected, since ALTS record guarantees that messages are unwrapped + // in the same order that the peer wrapped them. + inCounter counter + outCounter counter + inAEAD cipher.AEAD + outAEAD cipher.AEAD +} + +// NewAES128GCMRekey creates an instance that uses aes128gcm with rekeying +// for ALTS record. The key argument should be 44 bytes, the first 32 bytes +// are used as a key for HKDF-expand and the remainining 12 bytes are used +// as a random mask for the counter. +func NewAES128GCMRekey(side core.Side, key []byte) (ALTSRecordCrypto, error) { + inCounter := newInCounter(side, overflowLenAES128GCMRekey) + outCounter := newOutCounter(side, overflowLenAES128GCMRekey) + inAEAD, err := newRekeyAEAD(key) + if err != nil { + return nil, err + } + outAEAD, err := newRekeyAEAD(key) + if err != nil { + return nil, err + } + return &aes128gcmRekey{ + inCounter, + outCounter, + inAEAD, + outAEAD, + }, nil +} + +// Encrypt is the encryption function. dst can contain bytes at the beginning of +// the ciphertext that will not be encrypted but will be authenticated. If dst +// has enough capacity to hold these bytes, the ciphertext and the tag, no +// allocation and copy operations will be performed. dst and plaintext do not +// overlap. +func (s *aes128gcmRekey) Encrypt(dst, plaintext []byte) ([]byte, error) { + // If we need to allocate an output buffer, we want to include space for + // GCM tag to avoid forcing ALTS record to reallocate as well. + dlen := len(dst) + dst, out := SliceForAppend(dst, len(plaintext)+GcmTagSize) + seq, err := s.outCounter.Value() + if err != nil { + return nil, err + } + data := out[:len(plaintext)] + copy(data, plaintext) // data may alias plaintext + + // Seal appends the ciphertext and the tag to its first argument and + // returns the updated slice. However, SliceForAppend above ensures that + // dst has enough capacity to avoid a reallocation and copy due to the + // append. + dst = s.outAEAD.Seal(dst[:dlen], seq, data, nil) + s.outCounter.Inc() + return dst, nil +} + +func (s *aes128gcmRekey) EncryptionOverhead() int { + return GcmTagSize +} + +func (s *aes128gcmRekey) Decrypt(dst, ciphertext []byte) ([]byte, error) { + seq, err := s.inCounter.Value() + if err != nil { + return nil, err + } + plaintext, err := s.inAEAD.Open(dst, seq, ciphertext, nil) + if err != nil { + return nil, ErrAuth + } + s.inCounter.Inc() + return plaintext, nil +} diff --git a/credentials/alts/core/conn/aes128gcmrekey_test.go b/credentials/alts/core/conn/aes128gcmrekey_test.go new file mode 100644 index 00000000..9f31a0fa --- /dev/null +++ b/credentials/alts/core/conn/aes128gcmrekey_test.go @@ -0,0 +1,117 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "testing" + + "google.golang.org/grpc/credentials/alts/core" +) + +// cryptoTestVector is struct for a rekey test vector +type rekeyTestVector struct { + key, nonce, plaintext, ciphertext []byte +} + +// getGCMCryptoPair outputs a client/server pair on aes128gcmRekey. +func getRekeyCryptoPair(key []byte, counter []byte, t *testing.T) (ALTSRecordCrypto, ALTSRecordCrypto) { + client, err := NewAES128GCMRekey(core.ClientSide, key) + if err != nil { + t.Fatalf("NewAES128GCMRekey(ClientSide, key) = %v", err) + } + server, err := NewAES128GCMRekey(core.ServerSide, key) + if err != nil { + t.Fatalf("NewAES128GCMRekey(ServerSide, key) = %v", err) + } + // set counter if provided. + if counter != nil { + if counterSide(counter) == core.ClientSide { + client.(*aes128gcmRekey).outCounter = counterFromValue(counter, overflowLenAES128GCMRekey) + server.(*aes128gcmRekey).inCounter = counterFromValue(counter, overflowLenAES128GCMRekey) + } else { + server.(*aes128gcmRekey).outCounter = counterFromValue(counter, overflowLenAES128GCMRekey) + client.(*aes128gcmRekey).inCounter = counterFromValue(counter, overflowLenAES128GCMRekey) + } + } + return client, server +} + +func testRekeyEncryptRoundtrip(client ALTSRecordCrypto, server ALTSRecordCrypto, t *testing.T) { + // Encrypt. + const plaintext = "This is plaintext." + var err error + buf := []byte(plaintext) + buf, err = client.Encrypt(buf[:0], buf) + if err != nil { + t.Fatal("Encrypting with client-side context: unexpected error", err, "\n", + "Plaintext:", []byte(plaintext)) + } + + // Encrypt a second message. + const plaintext2 = "This is a second plaintext." + buf2 := []byte(plaintext2) + buf2, err = client.Encrypt(buf2[:0], buf2) + if err != nil { + t.Fatal("Encrypting with client-side context: unexpected error", err, "\n", + "Plaintext:", []byte(plaintext2)) + } + + // Decryption fails: cannot decrypt second message before first. + if got, err := server.Decrypt(nil, buf2); err == nil { + t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n", + " Original plaintext:", []byte(plaintext2), "\n", + " Ciphertext:", buf2, "\n", + " Decrypted plaintext:", got) + } + + // Decryption fails: wrong counter space. + if got, err := client.Decrypt(nil, buf); err == nil { + t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want counter space error:\n", + " Original plaintext:", []byte(plaintext), "\n", + " Ciphertext:", buf, "\n", + " Decrypted plaintext:", got) + } + + // Decrypt first message. + ciphertext := append([]byte(nil), buf...) + buf, err = server.Decrypt(buf[:0], buf) + if err != nil || string(buf) != plaintext { + t.Fatal("Decrypting client-side ciphertext with a server-side context did not produce original content:\n", + " Original plaintext:", []byte(plaintext), "\n", + " Ciphertext:", ciphertext, "\n", + " Decryption error:", err, "\n", + " Decrypted plaintext:", buf) + } + + // Decryption fails: replay attack. + if got, err := server.Decrypt(nil, buf); err == nil { + t.Error("Decrypting client-side ciphertext with a client-side context unexpectedly succeeded; want unexpected counter error:\n", + " Original plaintext:", []byte(plaintext), "\n", + " Ciphertext:", buf, "\n", + " Decrypted plaintext:", got) + } +} + +// Test encrypt and decrypt on roundtrip messages for aes128gcmRekey. +func TestAES128GCMRekeyEncryptRoundtrip(t *testing.T) { + // Test for aes128gcmRekey. + key := make([]byte, 44) + client, server := getRekeyCryptoPair(key, nil, t) + testRekeyEncryptRoundtrip(client, server, t) +} diff --git a/credentials/alts/core/conn/common.go b/credentials/alts/core/conn/common.go new file mode 100644 index 00000000..1795d0c9 --- /dev/null +++ b/credentials/alts/core/conn/common.go @@ -0,0 +1,70 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "encoding/binary" + "errors" + "fmt" +) + +const ( + // GcmTagSize is the GCM tag size is the difference in length between + // plaintext and ciphertext. From crypto/cipher/gcm.go in Go crypto + // library. + GcmTagSize = 16 +) + +// ErrAuth occurs on authentication failure. +var ErrAuth = errors.New("message authentication failed") + +// SliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +func SliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return head, tail +} + +// ParseFramedMsg parse the provided buffer and returns a frame of the format +// msgLength+msg and any remaining bytes in that buffer. +func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) { + // If the size field is not complete, return the provided buffer as + // remaining buffer. + if len(b) < MsgLenFieldSize { + return nil, b, nil + } + msgLenField := b[:MsgLenFieldSize] + length := binary.LittleEndian.Uint32(msgLenField) + if length > maxLen { + return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen) + } + if len(b) < int(length)+4 { // account for the first 4 msg length bytes. + // Frame is not complete yet. + return nil, b, nil + } + return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil +} diff --git a/credentials/alts/core/conn/counter.go b/credentials/alts/core/conn/counter.go new file mode 100644 index 00000000..754dcfaa --- /dev/null +++ b/credentials/alts/core/conn/counter.go @@ -0,0 +1,106 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "errors" + + "google.golang.org/grpc/credentials/alts/core" +) + +const counterLen = 12 + +var ( + errInvalidCounter = errors.New("invalid counter") +) + +// counter is a 96-bit, little-endian counter. +type counter struct { + value [counterLen]byte + invalid bool + overflowLen int +} + +// newOutCounter returns an outgoing counter initialized to the starting sequence +// number for the client/server side of a connection. +func newOutCounter(s core.Side, overflowLen int) (c counter) { + c.overflowLen = overflowLen + if s == core.ServerSide { + // Server counters in ALTS record have the little-endian high bit + // set. + c.value[counterLen-1] = 0x80 + } + return +} + +// newInCounter returns an incoming counter initialized to the starting sequence +// number for the client/server side of a connection. This is used in ALTS record +// to check that incoming counters are as expected, since ALTS record guarantees +// that messages are unwrapped in the same order that the peer wrapped them. +func newInCounter(s core.Side, overflowLen int) (c counter) { + c.overflowLen = overflowLen + if s == core.ClientSide { + // Server counters in ALTS record have the little-endian high bit + // set. + c.value[counterLen-1] = 0x80 + } + return +} + +// counterFromValue creates a new counter given an initial value. +func counterFromValue(value []byte, overflowLen int) (c counter) { + c.overflowLen = overflowLen + copy(c.value[:], value) + return +} + +// Value returns the current value of the counter as a byte slice. +func (c *counter) Value() ([]byte, error) { + if c.invalid { + return nil, errInvalidCounter + } + return c.value[:], nil +} + +// Inc increments the counter and checks for overflow. +func (c *counter) Inc() { + // If the counter is already invalid, there is not need to increase it. + if c.invalid { + return + } + i := 0 + for ; i < c.overflowLen; i++ { + c.value[i]++ + if c.value[i] != 0 { + break + } + } + if i == c.overflowLen { + c.invalid = true + } +} + +// counterSide returns the connection side (client/server) a sequence counter is +// associated with. +func counterSide(c []byte) core.Side { + if c[counterLen-1]&0x80 == 0x80 { + return core.ServerSide + } + return core.ClientSide +} diff --git a/credentials/alts/core/conn/counter_test.go b/credentials/alts/core/conn/counter_test.go new file mode 100644 index 00000000..faf56369 --- /dev/null +++ b/credentials/alts/core/conn/counter_test.go @@ -0,0 +1,141 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "bytes" + "testing" + + "google.golang.org/grpc/credentials/alts/core" +) + +const ( + testOverflowLen = 5 +) + +func TestCounterSides(t *testing.T) { + for _, side := range []core.Side{core.ClientSide, core.ServerSide} { + outCounter := newOutCounter(side, testOverflowLen) + inCounter := newInCounter(side, testOverflowLen) + for i := 0; i < 1024; i++ { + value, _ := outCounter.Value() + if g, w := counterSide(value), side; g != w { + t.Errorf("after %d iterations, counterSide(outCounter.Value()) = %v, want %v", i, g, w) + break + } + value, _ = inCounter.Value() + if g, w := counterSide(value), side; g == w { + t.Errorf("after %d iterations, counterSide(inCounter.Value()) = %v, want %v", i, g, w) + break + } + outCounter.Inc() + inCounter.Inc() + } + } +} + +func TestCounterInc(t *testing.T) { + for _, test := range []struct { + counter []byte + want []byte + }{ + { + counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + counter: []byte{0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80}, + want: []byte{0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x80}, + }, + { + counter: []byte{0xff, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + want: []byte{0x00, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + counter: []byte{0x42, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + want: []byte{0x43, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + counter: []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + }, + } { + c := counterFromValue(test.counter, overflowLenAES128GCM) + c.Inc() + value, _ := c.Value() + if g, w := value, test.want; !bytes.Equal(g, w) || c.invalid { + t.Errorf("counter(%v).Inc() =\n%v, want\n%v", test.counter, g, w) + } + } +} + +func TestRolloverCounter(t *testing.T) { + for _, test := range []struct { + desc string + value []byte + overflowLen int + }{ + { + desc: "testing overflow without rekeying 1", + value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}, + overflowLen: 5, + }, + { + desc: "testing overflow without rekeying 2", + value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + overflowLen: 5, + }, + { + desc: "testing overflow for rekeying mode 1", + value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x80}, + overflowLen: 8, + }, + { + desc: "testing overflow for rekeying mode 2", + value: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00}, + overflowLen: 8, + }, + } { + c := counterFromValue(test.value, overflowLenAES128GCM) + + // First Inc() + Value() should work. + c.Inc() + _, err := c.Value() + if err != nil { + t.Errorf("%v: first Inc() + Value() unexpectedly failed: %v, want error", test.desc, err) + } + // Second Inc() + Value() should fail. + c.Inc() + _, err = c.Value() + if err != errInvalidCounter { + t.Errorf("%v: second Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter) + } + // Third Inc() + Value() should also fail because the counter is + // already in an invalid state. + c.Inc() + _, err = c.Value() + if err != errInvalidCounter { + t.Errorf("%v: Third Inc() + Value() unexpectedly succeeded: want %v", test.desc, errInvalidCounter) + } + } +} diff --git a/credentials/alts/core/conn/record.go b/credentials/alts/core/conn/record.go new file mode 100644 index 00000000..cb514ebe --- /dev/null +++ b/credentials/alts/core/conn/record.go @@ -0,0 +1,271 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package conn contains an implementation of a secure channel created by gRPC +// handshakers. +package conn + +import ( + "encoding/binary" + "fmt" + "math" + "net" + + "google.golang.org/grpc/credentials/alts/core" +) + +// ALTSRecordCrypto is the interface for gRPC ALTS record protocol. +type ALTSRecordCrypto interface { + // Encrypt encrypts the plaintext and computes the tag (if any) of dst + // and plaintext, dst and plaintext do not overlap. + Encrypt(dst, plaintext []byte) ([]byte, error) + // EncryptionOverhead returns the tag size (if any) in bytes. + EncryptionOverhead() int + // Decrypt decrypts ciphertext and verify the tag (if any). dst and + // ciphertext may alias exactly or not at all. To reuse ciphertext's + // storage for the decrypted output, use ciphertext[:0] as dst. + Decrypt(dst, ciphertext []byte) ([]byte, error) +} + +// ALTSRecordFunc is a function type for factory functions that create +// ALTSRecordCrypto instances. +type ALTSRecordFunc func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) + +const ( + // MsgLenFieldSize is the byte size of the frame length field of a + // framed message. + MsgLenFieldSize = 4 + // The byte size of the message type field of a framed message. + msgTypeFieldSize = 4 + // The bytes size limit for a ALTS record message. + altsRecordLengthLimit = 1024 * 1024 // 1 MiB + // The default bytes size of a ALTS record message. + altsRecordDefaultLength = 4 * 1024 // 4KiB + // Message type value included in ALTS record framing. + altsRecordMsgType = uint32(0x06) + // The initial write buffer size. + altsWriteBufferInitialSize = 32 * 1024 // 32KiB + // The maximum write buffer size. This *must* be multiple of + // altsRecordDefaultLength. + altsWriteBufferMaxSize = 512 * 1024 // 512KiB +) + +var ( + protocols = make(map[string]ALTSRecordFunc) +) + +// RegisterProtocol register a ALTS record encryption protocol. +func RegisterProtocol(protocol string, f ALTSRecordFunc) error { + if _, ok := protocols[protocol]; ok { + return fmt.Errorf("protocol %v is already registered", protocol) + } + protocols[protocol] = f + return nil +} + +// conn represents a secured connection. It implements the net.Conn interface. +type conn struct { + net.Conn + crypto ALTSRecordCrypto + // buf holds data that has been read from the connection and decrypted, + // but has not yet been returned by Read. + buf []byte + payloadLengthLimit int + // protected holds data read from the network but have not yet been + // decrypted. This data might not compose a complete frame. + protected []byte + // writeBuf is a buffer used to contain encrypted frames before being + // written to the network. + writeBuf []byte + // nextFrame stores the next frame (in protected buffer) info. + nextFrame []byte + // overhead is the calculated overhead of each frame. + overhead int +} + +// NewConn creates a new secure channel instance given the other party role and +// handshaking result. +func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) { + newCrypto := protocols[recordProtocol] + if newCrypto == nil { + return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol) + } + crypto, err := newCrypto(side, key) + if err != nil { + return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err) + } + overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() + payloadLengthLimit := altsRecordDefaultLength - overhead + if protected == nil { + // We pre-allocate protected to be of size + // 2*altsRecordDefaultLength-1 during initialization. We only + // read from the network into protected when protected does not + // contain a complete frame, which is at most + // altsRecordDefaultLength-1 (bytes). And we read at most + // altsRecordDefaultLength (bytes) data into protected at one + // time. Therefore, 2*altsRecordDefaultLength-1 is large enough + // to buffer data read from the network. + protected = make([]byte, 0, 2*altsRecordDefaultLength-1) + } + + altsConn := &conn{ + Conn: c, + crypto: crypto, + payloadLengthLimit: payloadLengthLimit, + protected: protected, + writeBuf: make([]byte, altsWriteBufferInitialSize), + nextFrame: protected, + overhead: overhead, + } + return altsConn, nil +} + +// Read reads and decrypts a frame from the underlying connection, and copies the +// decrypted payload into b. If the size of the payload is greater than len(b), +// Read retains the remaining bytes in an internal buffer, and subsequent calls +// to Read will read from this buffer until it is exhausted. +func (p *conn) Read(b []byte) (n int, err error) { + if len(p.buf) == 0 { + var framedMsg []byte + framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit) + if err != nil { + return n, err + } + // Check whether the next frame to be decrypted has been + // completely received yet. + if len(framedMsg) == 0 { + copy(p.protected, p.nextFrame) + p.protected = p.protected[:len(p.nextFrame)] + // Always copy next incomplete frame to the beginning of + // the protected buffer and reset nextFrame to it. + p.nextFrame = p.protected + } + // Check whether a complete frame has been received yet. + for len(framedMsg) == 0 { + if len(p.protected) == cap(p.protected) { + tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength) + copy(tmp, p.protected) + p.protected = tmp + } + n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)]) + if err != nil { + return 0, err + } + p.protected = p.protected[:len(p.protected)+n] + framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit) + if err != nil { + return 0, err + } + } + // Now we have a complete frame, decrypted it. + msg := framedMsg[MsgLenFieldSize:] + msgType := binary.LittleEndian.Uint32(msg[:msgTypeFieldSize]) + if msgType&0xff != altsRecordMsgType { + return 0, fmt.Errorf("received frame with incorrect message type %v, expected lower byte %v", + msgType, altsRecordMsgType) + } + ciphertext := msg[msgTypeFieldSize:] + + // Decrypt requires that if the dst and ciphertext alias, they + // must alias exactly. Code here used to use msg[:0], but msg + // starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than + // ciphertext, so they alias inexactly. Using ciphertext[:0] + // arranges the appropriate aliasing without needing to copy + // ciphertext or use a separate destination buffer. For more info + // check: https://golang.org/pkg/crypto/cipher/#AEAD. + p.buf, err = p.crypto.Decrypt(ciphertext[:0], ciphertext) + if err != nil { + return 0, err + } + } + + n = copy(b, p.buf) + p.buf = p.buf[n:] + return n, nil +} + +// Write encrypts, frames, and writes bytes from b to the underlying connection. +func (p *conn) Write(b []byte) (n int, err error) { + n = len(b) + // Calculate the output buffer size with framing and encryption overhead. + numOfFrames := int(math.Ceil(float64(len(b)) / float64(p.payloadLengthLimit))) + size := len(b) + numOfFrames*p.overhead + // If writeBuf is too small, increase its size up to the maximum size. + partialBSize := len(b) + if size > altsWriteBufferMaxSize { + size = altsWriteBufferMaxSize + const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength + partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit + } + if len(p.writeBuf) < size { + p.writeBuf = make([]byte, size) + } + + for partialBStart := 0; partialBStart < len(b); partialBStart += partialBSize { + partialBEnd := partialBStart + partialBSize + if partialBEnd > len(b) { + partialBEnd = len(b) + } + partialB := b[partialBStart:partialBEnd] + writeBufIndex := 0 + for len(partialB) > 0 { + payloadLen := len(partialB) + if payloadLen > p.payloadLengthLimit { + payloadLen = p.payloadLengthLimit + } + buf := partialB[:payloadLen] + partialB = partialB[payloadLen:] + + // Write buffer contains: length, type, payload, and tag + // if any. + + // 1. Fill in type field. + msg := p.writeBuf[writeBufIndex+MsgLenFieldSize:] + binary.LittleEndian.PutUint32(msg, altsRecordMsgType) + + // 2. Encrypt the payload and create a tag if any. + msg, err = p.crypto.Encrypt(msg[:msgTypeFieldSize], buf) + if err != nil { + return n, err + } + + // 3. Fill in the size field. + binary.LittleEndian.PutUint32(p.writeBuf[writeBufIndex:], uint32(len(msg))) + + // 4. Increase writeBufIndex. + writeBufIndex += len(buf) + p.overhead + } + nn, err := p.Conn.Write(p.writeBuf[:writeBufIndex]) + if err != nil { + // We need to calculate the actual data size that was + // written. This means we need to remove header, + // encryption overheads, and any partially-written + // frame data. + numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength))) + return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err + } + } + return n, nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/credentials/alts/core/conn/record_test.go b/credentials/alts/core/conn/record_test.go new file mode 100644 index 00000000..5c3b8e2c --- /dev/null +++ b/credentials/alts/core/conn/record_test.go @@ -0,0 +1,274 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + "net" + "reflect" + "testing" + + "google.golang.org/grpc/credentials/alts/core" +) + +var ( + nextProtocols = []string{"ALTSRP_GCM_AES128"} + altsRecordFuncs = map[string]ALTSRecordFunc{ + // ALTS handshaker protocols. + "ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) { + return NewAES128GCM(s, keyData) + }, + } +) + +func init() { + for protocol, f := range altsRecordFuncs { + if err := RegisterProtocol(protocol, f); err != nil { + panic(err) + } + } +} + +// testConn mimics a net.Conn to the peer. +type testConn struct { + net.Conn + in *bytes.Buffer + out *bytes.Buffer +} + +func (c *testConn) Read(b []byte) (n int, err error) { + return c.in.Read(b) +} + +func (c *testConn) Write(b []byte) (n int, err error) { + return c.out.Write(b) +} + +func (c *testConn) Close() error { + return nil +} + +func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn { + key := []byte{ + // 16 arbitrary bytes. + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} + tc := testConn{ + in: in, + out: out, + } + c, err := NewConn(&tc, side, np, key, nil) + if err != nil { + panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) + } + return c.(*conn) +} + +func newConnPair(np string) (client, server *conn) { + clientBuf := new(bytes.Buffer) + serverBuf := new(bytes.Buffer) + clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np) + serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np) + return clientConn, serverConn +} + +func testPingPong(t *testing.T, np string) { + clientConn, serverConn := newConnPair(np) + clientMsg := []byte("Client Message") + if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { + t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) + } + rcvClientMsg := make([]byte, len(clientMsg)) + if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { + t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(rcvClientMsg)) + } + if !reflect.DeepEqual(clientMsg, rcvClientMsg) { + t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) + } + + serverMsg := []byte("Server Message") + if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil { + t.Fatalf("Server Write() = %v, %v; want %v, ", n, err, len(serverMsg)) + } + rcvServerMsg := make([]byte, len(serverMsg)) + if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil { + t.Fatalf("Client Read() = %v, %v; want %v, ", n, err, len(rcvServerMsg)) + } + if !reflect.DeepEqual(serverMsg, rcvServerMsg) { + t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg) + } +} + +func TestPingPong(t *testing.T) { + for _, np := range nextProtocols { + testPingPong(t, np) + } +} + +func testSmallReadBuffer(t *testing.T, np string) { + clientConn, serverConn := newConnPair(np) + msg := []byte("Very Important Message") + if n, err := clientConn.Write(msg); err != nil { + t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) + } + rcvMsg := make([]byte, len(msg)) + n := 2 // Arbitrary index to break rcvMsg in two. + rcvMsg1 := rcvMsg[:n] + rcvMsg2 := rcvMsg[n:] + if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil { + t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg1)) + } + if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil { + t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg2)) + } + if !reflect.DeepEqual(msg, rcvMsg) { + t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg) + } +} + +func TestSmallReadBuffer(t *testing.T) { + for _, np := range nextProtocols { + testSmallReadBuffer(t, np) + } +} + +func testLargeMsg(t *testing.T, np string) { + clientConn, serverConn := newConnPair(np) + // msgLen is such that the length in the framing is larger than the + // default size of one frame. + msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 + msg := make([]byte, msgLen) + if n, err := clientConn.Write(msg); n != len(msg) || err != nil { + t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) + } + rcvMsg := make([]byte, len(msg)) + if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil { + t.Fatalf("Read() = %v, %v; want %v, ", n, err, len(rcvMsg)) + } + if !reflect.DeepEqual(msg, rcvMsg) { + t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg) + } +} + +func TestLargeMsg(t *testing.T) { + for _, np := range nextProtocols { + testLargeMsg(t, np) + } +} + +func testIncorrectMsgType(t *testing.T, np string) { + // framedMsg is an empty ciphertext with correct framing but wrong + // message type. + framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize) + binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize) + wrongMsgType := uint32(0x22) + binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) + + in := bytes.NewBuffer(framedMsg) + c := newTestALTSRecordConn(in, nil, core.ClientSide, np) + b := make([]byte, 1) + if n, err := c.Read(b); n != 0 || err == nil { + t.Fatalf("Read() = , want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) + } +} + +func TestIncorrectMsgType(t *testing.T) { + for _, np := range nextProtocols { + testIncorrectMsgType(t, np) + } +} + +func testFrameTooLarge(t *testing.T, np string) { + buf := new(bytes.Buffer) + clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np) + serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np) + // payloadLen is such that the length in the framing is larger than + // allowed in one frame. + payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 + payload := make([]byte, payloadLen) + c, err := clientConn.crypto.Encrypt(nil, payload) + if err != nil { + t.Fatalf(fmt.Sprintf("Error encrypting message: %v", err)) + } + msgLen := msgTypeFieldSize + len(c) + framedMsg := make([]byte, MsgLenFieldSize+msgLen) + binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c))) + msg := framedMsg[MsgLenFieldSize:] + binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType) + copy(msg[msgTypeFieldSize:], c) + if _, err = buf.Write(framedMsg); err != nil { + t.Fatal(fmt.Sprintf("Unexpected error writing to buffer: %v", err)) + } + b := make([]byte, 1) + if n, err := serverConn.Read(b); n != 0 || err == nil { + t.Fatalf("Read() = , want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit)) + } +} + +func TestFrameTooLarge(t *testing.T) { + for _, np := range nextProtocols { + testFrameTooLarge(t, np) + } +} + +func testWriteLargeData(t *testing.T, np string) { + // Test sending and receiving messages larger than the maximum write + // buffer size. + clientConn, serverConn := newConnPair(np) + // Message size is intentionally chosen to not be multiple of + // payloadLengthLimtit. + msgSize := altsWriteBufferMaxSize + (100 * 1024) + clientMsg := make([]byte, msgSize) + for i := 0; i < msgSize; i++ { + clientMsg[i] = 0xAA + } + if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { + t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) + } + // We need to keep reading until the entire message is received. The + // reason we set all bytes of the message to a value other than zero is + // to avoid ambiguous zero-init value of rcvClientMsg buffer and the + // actual received data. + rcvClientMsg := make([]byte, 0, msgSize) + numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit))) + for i := 0; i < numberOfExpectedFrames; i++ { + expectedRcvSize := serverConn.payloadLengthLimit + if i == numberOfExpectedFrames-1 { + // Last frame might be smaller. + expectedRcvSize = msgSize % serverConn.payloadLengthLimit + } + tmpBuf := make([]byte, expectedRcvSize) + if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil { + t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(tmpBuf)) + } + rcvClientMsg = append(rcvClientMsg, tmpBuf...) + } + if !reflect.DeepEqual(clientMsg, rcvClientMsg) { + t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg) + } +} + +func TestWriteLargeData(t *testing.T) { + for _, np := range nextProtocols { + testWriteLargeData(t, np) + } +} diff --git a/credentials/alts/core/handshaker/handshaker.go b/credentials/alts/core/handshaker/handshaker.go new file mode 100644 index 00000000..ac5a3385 --- /dev/null +++ b/credentials/alts/core/handshaker/handshaker.go @@ -0,0 +1,364 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package handshaker provides ALTS handshaking functionality for GCP. +package handshaker + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + + "golang.org/x/net/context" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/alts/core" + "google.golang.org/grpc/credentials/alts/core/authinfo" + "google.golang.org/grpc/credentials/alts/core/conn" + altspb "google.golang.org/grpc/credentials/alts/core/proto" +) + +const ( + // The maximum byte size of receive frames. + frameLimit = 64 * 1024 // 64 KB + rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY" + // maxPendingHandshakes represents the maximum number of concurrent + // handshakes. + maxPendingHandshakes = 100 +) + +var ( + hsProtocol = altspb.HandshakeProtocol_ALTS + appProtocols = []string{"grpc"} + recordProtocols = []string{rekeyRecordProtocolName} + keyLength = map[string]int{ + rekeyRecordProtocolName: 44, + } + altsRecordFuncs = map[string]conn.ALTSRecordFunc{ + // ALTS handshaker protocols. + rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) { + return conn.NewAES128GCMRekey(s, keyData) + }, + } + // control number of concurrent created (but not closed) handshakers. + mu sync.Mutex + concurrentHandshakes = int64(0) + // errDropped occurs when maxPendingHandshakes is reached. + errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached") +) + +func init() { + for protocol, f := range altsRecordFuncs { + if err := conn.RegisterProtocol(protocol, f); err != nil { + panic(err) + } + } +} + +func acquire(n int64) bool { + mu.Lock() + success := maxPendingHandshakes-concurrentHandshakes >= n + if success { + concurrentHandshakes += n + } + mu.Unlock() + return success +} + +func release(n int64) { + mu.Lock() + concurrentHandshakes -= n + if concurrentHandshakes < 0 { + mu.Unlock() + panic("bad release") + } + mu.Unlock() +} + +// ClientHandshakerOptions contains the client handshaker options that can +// provided by the caller. +type ClientHandshakerOptions struct { + // ClientIdentity is the handshaker client local identity. + ClientIdentity *altspb.Identity + // TargetName is the server service account name for secure name + // checking. + TargetName string + // TargetServiceAccounts contains a list of expected target service + // accounts. One of these accounts should match one of the accounts in + // the handshaker results. Otherwise, the handshake fails. + TargetServiceAccounts []string + // RPCVersions specifies the gRPC versions accepted by the client. + RPCVersions *altspb.RpcProtocolVersions +} + +// ServerHandshakerOptions contains the server handshaker options that can +// provided by the caller. +type ServerHandshakerOptions struct { + // RPCVersions specifies the gRPC versions accepted by the server. + RPCVersions *altspb.RpcProtocolVersions +} + +// DefaultClientHandshakerOptions returns the default client handshaker options. +func DefaultClientHandshakerOptions() *ClientHandshakerOptions { + return &ClientHandshakerOptions{} +} + +// DefaultServerHandshakerOptions returns the default client handshaker options. +func DefaultServerHandshakerOptions() *ServerHandshakerOptions { + return &ServerHandshakerOptions{} +} + +// TODO: add support for future local and remote endpoint in both client options +// and server options (server options struct does not exist now. When +// caller can provide endpoints, it should be created. + +// altsHandshaker is used to complete a ALTS handshaking between client and +// server. This handshaker talks to the ALTS handshaker service in the metadata +// server. +type altsHandshaker struct { + // RPC stream used to access the ALTS Handshaker service. + stream altspb.HandshakerService_DoHandshakeClient + // the connection to the peer. + conn net.Conn + // client handshake options. + clientOpts *ClientHandshakerOptions + // server handshake options. + serverOpts *ServerHandshakerOptions + // defines the side doing the handshake, client or server. + side core.Side +} + +// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC +// stub created using the passed conn and used to talk to the ALTS Handshaker +// service in the metadata server. +func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) { + stream, err := altspb.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false)) + if err != nil { + return nil, err + } + return &altsHandshaker{ + stream: stream, + conn: c, + clientOpts: opts, + side: core.ClientSide, + }, nil +} + +// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC +// stub created using the passed conn and used to talk to the ALTS Handshaker +// service in the metadata server. +func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) { + stream, err := altspb.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false)) + if err != nil { + return nil, err + } + return &altsHandshaker{ + stream: stream, + conn: c, + serverOpts: opts, + side: core.ServerSide, + }, nil +} + +// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once +// done, ClientHandshake returns a secure connection. +func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { + if !acquire(1) { + return nil, nil, errDropped + } + defer release(1) + + if h.side != core.ClientSide { + return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker") + } + + // Create target identities from service account list. + targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts)) + for _, account := range h.clientOpts.TargetServiceAccounts { + targetIdentities = append(targetIdentities, &altspb.Identity{ + IdentityOneof: &altspb.Identity_ServiceAccount{ + ServiceAccount: account, + }, + }) + } + req := &altspb.HandshakerReq{ + ReqOneof: &altspb.HandshakerReq_ClientStart{ + ClientStart: &altspb.StartClientHandshakeReq{ + HandshakeSecurityProtocol: hsProtocol, + ApplicationProtocols: appProtocols, + RecordProtocols: recordProtocols, + TargetIdentities: targetIdentities, + LocalIdentity: h.clientOpts.ClientIdentity, + TargetName: h.clientOpts.TargetName, + RpcVersions: h.clientOpts.RPCVersions, + }, + }, + } + + conn, result, err := h.doHandshake(req) + if err != nil { + return nil, nil, err + } + authInfo := authinfo.New(result) + return conn, authInfo, nil +} + +// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once +// done, ServerHandshake returns a secure connection. +func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { + if !acquire(1) { + return nil, nil, errDropped + } + defer release(1) + + if h.side != core.ServerSide { + return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker") + } + + p := make([]byte, frameLimit) + n, err := h.conn.Read(p) + if err != nil { + return nil, nil, err + } + + // Prepare server parameters. + // TODO: currently only ALTS parameters are provided. Might need to use + // more options in the future. + params := make(map[int32]*altspb.ServerHandshakeParameters) + params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{ + RecordProtocols: recordProtocols, + } + req := &altspb.HandshakerReq{ + ReqOneof: &altspb.HandshakerReq_ServerStart{ + ServerStart: &altspb.StartServerHandshakeReq{ + ApplicationProtocols: appProtocols, + HandshakeParameters: params, + InBytes: p[:n], + RpcVersions: h.serverOpts.RPCVersions, + }, + }, + } + + conn, result, err := h.doHandshake(req) + if err != nil { + return nil, nil, err + } + authInfo := authinfo.New(result) + return conn, authInfo, nil +} + +func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) { + resp, err := h.accessHandshakerService(req) + if err != nil { + return nil, nil, err + } + // Check of the returned status is an error. + if resp.GetStatus() != nil { + if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want { + return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details) + } + } + + var extra []byte + if req.GetServerStart() != nil { + extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():] + } + result, extra, err := h.processUntilDone(resp, extra) + if err != nil { + return nil, nil, err + } + // The handshaker returns a 128 bytes key. It should be truncated based + // on the returned record protocol. + keyLen, ok := keyLength[result.RecordProtocol] + if !ok { + return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.RecordProtocol) + } + sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra) + if err != nil { + return nil, nil, err + } + return sc, result, nil +} + +func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) { + if err := h.stream.Send(req); err != nil { + return nil, err + } + resp, err := h.stream.Recv() + if err != nil { + return nil, err + } + return resp, nil +} + +// processUntilDone processes the handshake until the handshaker service returns +// the results. Handshaker service takes care of frame parsing, so we read +// whatever received from the network and send it to the handshaker service. +func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) { + for { + if len(resp.OutFrames) > 0 { + if _, err := h.conn.Write(resp.OutFrames); err != nil { + return nil, nil, err + } + } + if resp.Result != nil { + return resp.Result, extra, nil + } + buf := make([]byte, frameLimit) + n, err := h.conn.Read(buf) + if err != nil && err != io.EOF { + return nil, nil, err + } + // If there is nothing to send to the handshaker service, and + // nothing is received from the peer, then we are stuck. + // This covers the case when the peer is not responding. Note + // that handshaker service connection issues are caught in + // accessHandshakerService before we even get here. + if len(resp.OutFrames) == 0 && n == 0 { + return nil, nil, core.PeerNotRespondingError + } + // Append extra bytes from the previous interaction with the + // handshaker service with the current buffer read from conn. + p := append(extra, buf[:n]...) + resp, err = h.accessHandshakerService(&altspb.HandshakerReq{ + ReqOneof: &altspb.HandshakerReq_Next{ + Next: &altspb.NextHandshakeMessageReq{ + InBytes: p, + }, + }, + }) + if err != nil { + return nil, nil, err + } + // Set extra based on handshaker service response. + if n == 0 { + extra = nil + } else { + extra = buf[resp.GetBytesConsumed():n] + } + } +} + +// Close terminates the Handshaker. It should be called when the caller obtains +// the secure connection. +func (h *altsHandshaker) Close() { + h.stream.CloseSend() +} diff --git a/credentials/alts/core/handshaker/handshaker_test.go b/credentials/alts/core/handshaker/handshaker_test.go new file mode 100644 index 00000000..d01cd1d9 --- /dev/null +++ b/credentials/alts/core/handshaker/handshaker_test.go @@ -0,0 +1,261 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package handshaker + +import ( + "bytes" + "testing" + "time" + + "golang.org/x/net/context" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/credentials/alts/core" + altspb "google.golang.org/grpc/credentials/alts/core/proto" + "google.golang.org/grpc/credentials/alts/core/testutil" +) + +var ( + testAppProtocols = []string{"grpc"} + testRecordProtocol = rekeyRecordProtocolName + testKey = []byte{ + // 44 arbitrary bytes. + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, + 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b, + 0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, + } + testServiceAccount = "test_service_account" + testTargetServiceAccounts = []string{testServiceAccount} + testClientIdentity = &altspb.Identity{ + IdentityOneof: &altspb.Identity_Hostname{ + Hostname: "i_am_a_client", + }, + } +) + +// testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object. +type testRPCStream struct { + grpc.ClientStream + t *testing.T + isClient bool + // The resp expected to be returned by Recv(). Make sure this is set to + // the content the test requires before Recv() is invoked. + recvBuf *altspb.HandshakerResp + // false if it is the first access to Handshaker service on Envelope. + first bool + // useful for testing concurrent calls. + delay time.Duration +} + +func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) { + resp := t.recvBuf + t.recvBuf = nil + return resp, nil +} + +func (t *testRPCStream) Send(req *altspb.HandshakerReq) error { + var resp *altspb.HandshakerResp + if !t.first { + // Generate the bytes to be returned by Recv() for the initial + // handshaking. + t.first = true + if t.isClient { + resp = &altspb.HandshakerResp{ + OutFrames: testutil.MakeFrame("ClientInit"), + // Simulate consuming ServerInit. + BytesConsumed: 14, + } + } else { + resp = &altspb.HandshakerResp{ + OutFrames: testutil.MakeFrame("ServerInit"), + // Simulate consuming ClientInit. + BytesConsumed: 14, + } + } + } else { + // Add delay to test concurrent calls. + close := stat.Update() + defer close() + time.Sleep(t.delay) + + // Generate the response to be returned by Recv() for the + // follow-up handshaking. + result := &altspb.HandshakerResult{ + RecordProtocol: testRecordProtocol, + KeyData: testKey, + } + resp = &altspb.HandshakerResp{ + Result: result, + // Simulate consuming ClientFinished or ServerFinished. + BytesConsumed: 18, + } + } + t.recvBuf = resp + return nil +} + +func (t *testRPCStream) CloseSend() error { + return nil +} + +var stat testutil.Stats + +func TestClientHandshake(t *testing.T) { + for _, testCase := range []struct { + delay time.Duration + numberOfHandshakes int + }{ + {0 * time.Millisecond, 1}, + {100 * time.Millisecond, 10 * maxPendingHandshakes}, + } { + errc := make(chan error) + stat.Reset() + for i := 0; i < testCase.numberOfHandshakes; i++ { + stream := &testRPCStream{ + t: t, + isClient: true, + } + // Preload the inbound frames. + f1 := testutil.MakeFrame("ServerInit") + f2 := testutil.MakeFrame("ServerFinished") + in := bytes.NewBuffer(f1) + in.Write(f2) + out := new(bytes.Buffer) + tc := testutil.NewTestConn(in, out) + chs := &altsHandshaker{ + stream: stream, + conn: tc, + clientOpts: &ClientHandshakerOptions{ + TargetServiceAccounts: testTargetServiceAccounts, + ClientIdentity: testClientIdentity, + }, + side: core.ClientSide, + } + go func() { + _, context, err := chs.ClientHandshake(context.Background()) + if err == nil && context == nil { + panic("expected non-nil ALTS context") + } + errc <- err + chs.Close() + }() + } + + // Ensure all errors are expected. + for i := 0; i < testCase.numberOfHandshakes; i++ { + if err := <-errc; err != nil && err != errDropped { + t.Errorf("ClientHandshake() = _, %v, want _, or %v", err, errDropped) + } + } + + // Ensure that there are no concurrent calls more than the limit. + if stat.MaxConcurrentCalls > maxPendingHandshakes { + t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes) + } + } +} + +func TestServerHandshake(t *testing.T) { + for _, testCase := range []struct { + delay time.Duration + numberOfHandshakes int + }{ + {0 * time.Millisecond, 1}, + {100 * time.Millisecond, 10 * maxPendingHandshakes}, + } { + errc := make(chan error) + stat.Reset() + for i := 0; i < testCase.numberOfHandshakes; i++ { + stream := &testRPCStream{ + t: t, + isClient: false, + } + // Preload the inbound frames. + f1 := testutil.MakeFrame("ClientInit") + f2 := testutil.MakeFrame("ClientFinished") + in := bytes.NewBuffer(f1) + in.Write(f2) + out := new(bytes.Buffer) + tc := testutil.NewTestConn(in, out) + shs := &altsHandshaker{ + stream: stream, + conn: tc, + serverOpts: DefaultServerHandshakerOptions(), + side: core.ServerSide, + } + go func() { + _, context, err := shs.ServerHandshake(context.Background()) + if err == nil && context == nil { + panic("expected non-nil ALTS context") + } + errc <- err + shs.Close() + }() + } + + // Ensure all errors are expected. + for i := 0; i < testCase.numberOfHandshakes; i++ { + if err := <-errc; err != nil && err != errDropped { + t.Errorf("ServerHandshake() = _, %v, want _, or %v", err, errDropped) + } + } + + // Ensure that there are no concurrent calls more than the limit. + if stat.MaxConcurrentCalls > maxPendingHandshakes { + t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes) + } + } +} + +// testUnresponsiveRPCStream is used for testing the PeerNotResponding case. +type testUnresponsiveRPCStream struct { + grpc.ClientStream +} + +func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) { + return &altspb.HandshakerResp{}, nil +} + +func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error { + return nil +} + +func (t *testUnresponsiveRPCStream) CloseSend() error { + return nil +} + +func TestPeerNotResponding(t *testing.T) { + stream := &testUnresponsiveRPCStream{} + chs := &altsHandshaker{ + stream: stream, + conn: testutil.NewUnresponsiveTestConn(), + clientOpts: &ClientHandshakerOptions{ + TargetServiceAccounts: testTargetServiceAccounts, + ClientIdentity: testClientIdentity, + }, + side: core.ClientSide, + } + _, context, err := chs.ClientHandshake(context.Background()) + chs.Close() + if context != nil { + t.Error("expected non-nil ALTS context") + } + if got, want := err, core.PeerNotRespondingError; got != want { + t.Errorf("ClientHandshake() = %v, want %v", got, want) + } +} diff --git a/credentials/alts/core/handshaker/service/service.go b/credentials/alts/core/handshaker/service/service.go new file mode 100644 index 00000000..a839369f --- /dev/null +++ b/credentials/alts/core/handshaker/service/service.go @@ -0,0 +1,60 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package service manages connections between the VM application and the ALTS +// handshaker service. +package service + +import ( + "flag" + "sync" + + grpc "google.golang.org/grpc" +) + +var ( + // hsServiceAddr specifies the default ALTS handshaker service address in + // the hypervisor. + hsServiceAddr = flag.String("handshaker_service_address", "metadata.google.internal:8080", "ALTS handshaker gRPC service address") + // hsConn represents a connection to hypervisor handshaker service. + hsConn *grpc.ClientConn + mu sync.Mutex + // hsDialer will be reassigned in tests. + hsDialer = grpc.Dial +) + +type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) + +// Dial dials the handshake service in the hypervisor. If a connection has +// already been established, this function returns it. Otherwise, a new +// connection is created, +func Dial() (*grpc.ClientConn, error) { + mu.Lock() + defer mu.Unlock() + + if hsConn == nil { + // Create a new connection to the handshaker service. Note that + // this connection stays open until the application is closed. + var err error + hsConn, err = hsDialer(*hsServiceAddr, grpc.WithInsecure()) + if err != nil { + return nil, err + } + } + return hsConn, nil +} diff --git a/credentials/alts/core/handshaker/service/service_test.go b/credentials/alts/core/handshaker/service/service_test.go new file mode 100644 index 00000000..2f33def1 --- /dev/null +++ b/credentials/alts/core/handshaker/service/service_test.go @@ -0,0 +1,64 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package service + +import ( + "testing" + + grpc "google.golang.org/grpc" +) + +func TestDial(t *testing.T) { + defer func() func() { + temp := hsDialer + hsDialer = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return &grpc.ClientConn{}, nil + } + return func() { + hsDialer = temp + } + }() + + // Ensure that hsConn is nil at first. + hsConn = nil + + // First call to Dial, it should create set hsConn. + conn1, err := Dial() + if err != nil { + t.Fatalf("first call to Dial failed: %v", err) + } + if conn1 == nil { + t.Fatal("first call to Dial()=(nil, _), want not nil") + } + if got, want := hsConn, conn1; got != want { + t.Fatalf("hsConn=%v, want %v", got, want) + } + + // Second call to Dial() should return conn1 above. + conn2, err := Dial() + if err != nil { + t.Fatalf("second call to Dial() failed: %v", err) + } + if got, want := conn2, conn1; got != want { + t.Fatalf("second call to Dial()=(%v, _), want (%v,. _)", got, want) + } + if got, want := hsConn, conn1; got != want { + t.Fatalf("hsConn=%v, want %v", got, want) + } +} diff --git a/credentials/alts/core/proto/altscontext.pb.go b/credentials/alts/core/proto/altscontext.pb.go new file mode 100644 index 00000000..cb1dbab3 --- /dev/null +++ b/credentials/alts/core/proto/altscontext.pb.go @@ -0,0 +1,131 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: altscontext.proto + +/* +Package grpc_gcp is a generated protocol buffer package. + +It is generated from these files: + altscontext.proto + handshaker.proto + transport_security_common.proto + +It has these top-level messages: + AltsContext + Endpoint + Identity + StartClientHandshakeReq + ServerHandshakeParameters + StartServerHandshakeReq + NextHandshakeMessageReq + HandshakerReq + HandshakerResult + HandshakerStatus + HandshakerResp + RpcProtocolVersions +*/ +package grpc_gcp + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type AltsContext struct { + // The application protocol negotiated for this connection. + ApplicationProtocol string `protobuf:"bytes,1,opt,name=application_protocol,json=applicationProtocol" json:"application_protocol,omitempty"` + // The record protocol negotiated for this connection. + RecordProtocol string `protobuf:"bytes,2,opt,name=record_protocol,json=recordProtocol" json:"record_protocol,omitempty"` + // The security level of the created secure channel. + SecurityLevel SecurityLevel `protobuf:"varint,3,opt,name=security_level,json=securityLevel,enum=grpc.gcp.SecurityLevel" json:"security_level,omitempty"` + // The peer service account. + PeerServiceAccount string `protobuf:"bytes,4,opt,name=peer_service_account,json=peerServiceAccount" json:"peer_service_account,omitempty"` + // The local service account. + LocalServiceAccount string `protobuf:"bytes,5,opt,name=local_service_account,json=localServiceAccount" json:"local_service_account,omitempty"` + // The RPC protocol versions supported by the peer. + PeerRpcVersions *RpcProtocolVersions `protobuf:"bytes,6,opt,name=peer_rpc_versions,json=peerRpcVersions" json:"peer_rpc_versions,omitempty"` +} + +func (m *AltsContext) Reset() { *m = AltsContext{} } +func (m *AltsContext) String() string { return proto.CompactTextString(m) } +func (*AltsContext) ProtoMessage() {} +func (*AltsContext) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func (m *AltsContext) GetApplicationProtocol() string { + if m != nil { + return m.ApplicationProtocol + } + return "" +} + +func (m *AltsContext) GetRecordProtocol() string { + if m != nil { + return m.RecordProtocol + } + return "" +} + +func (m *AltsContext) GetSecurityLevel() SecurityLevel { + if m != nil { + return m.SecurityLevel + } + return SecurityLevel_SECURITY_NONE +} + +func (m *AltsContext) GetPeerServiceAccount() string { + if m != nil { + return m.PeerServiceAccount + } + return "" +} + +func (m *AltsContext) GetLocalServiceAccount() string { + if m != nil { + return m.LocalServiceAccount + } + return "" +} + +func (m *AltsContext) GetPeerRpcVersions() *RpcProtocolVersions { + if m != nil { + return m.PeerRpcVersions + } + return nil +} + +func init() { + proto.RegisterType((*AltsContext)(nil), "grpc.gcp.AltsContext") +} + +func init() { proto.RegisterFile("altscontext.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 280 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x64, 0x91, 0x41, 0x4b, 0x33, 0x31, + 0x10, 0x86, 0xd9, 0x7e, 0x9f, 0x45, 0x53, 0xdd, 0xd2, 0x58, 0x71, 0x11, 0xc4, 0xe2, 0xc5, 0x9e, + 0x16, 0x5d, 0xef, 0x42, 0xf5, 0x24, 0x78, 0x90, 0x2d, 0x78, 0x0d, 0x71, 0x0c, 0x25, 0x90, 0x66, + 0xc2, 0x24, 0x5d, 0xf4, 0xaf, 0xfa, 0x6b, 0x64, 0x93, 0xdd, 0xb6, 0xe8, 0x31, 0xf3, 0x3c, 0x6f, + 0x66, 0x26, 0x61, 0x13, 0x69, 0x82, 0x07, 0xb4, 0x41, 0x7d, 0x86, 0xd2, 0x11, 0x06, 0xe4, 0x87, + 0x2b, 0x72, 0x50, 0xae, 0xc0, 0x5d, 0x5c, 0x05, 0x92, 0xd6, 0x3b, 0xa4, 0x20, 0xbc, 0x82, 0x0d, + 0xe9, 0xf0, 0x25, 0x00, 0xd7, 0x6b, 0xb4, 0x49, 0xbd, 0xfe, 0x1e, 0xb0, 0xd1, 0xc2, 0x04, 0xff, + 0x94, 0x2e, 0xe0, 0x77, 0x6c, 0x2a, 0x9d, 0x33, 0x1a, 0x64, 0xd0, 0x68, 0x45, 0x94, 0x00, 0x4d, + 0x91, 0xcd, 0xb2, 0xf9, 0x51, 0x7d, 0xba, 0xc7, 0x5e, 0x3b, 0xc4, 0x6f, 0xd8, 0x98, 0x14, 0x20, + 0x7d, 0xec, 0xec, 0x41, 0xb4, 0xf3, 0x54, 0xde, 0x8a, 0x0f, 0x2c, 0xdf, 0x0e, 0x61, 0x54, 0xa3, + 0x4c, 0xf1, 0x6f, 0x96, 0xcd, 0xf3, 0xea, 0xbc, 0xec, 0xe7, 0x2d, 0x97, 0x1d, 0x7f, 0x69, 0x71, + 0x7d, 0xe2, 0xf7, 0x8f, 0xfc, 0x96, 0x4d, 0x9d, 0x52, 0x24, 0xbc, 0xa2, 0x46, 0x83, 0x12, 0x12, + 0x00, 0x37, 0x36, 0x14, 0xff, 0x63, 0x37, 0xde, 0xb2, 0x65, 0x42, 0x8b, 0x44, 0x78, 0xc5, 0xce, + 0x0c, 0x82, 0x34, 0x7f, 0x22, 0x07, 0x69, 0x9d, 0x08, 0x7f, 0x65, 0x9e, 0xd9, 0x24, 0x76, 0x21, + 0x07, 0xa2, 0x51, 0xe4, 0x35, 0x5a, 0x5f, 0x0c, 0x67, 0xd9, 0x7c, 0x54, 0x5d, 0xee, 0x06, 0xad, + 0x1d, 0xf4, 0x7b, 0xbd, 0x75, 0x52, 0x3d, 0x6e, 0x73, 0xb5, 0x83, 0xbe, 0xf0, 0x98, 0xb3, 0x63, + 0x8d, 0x29, 0xd3, 0x7e, 0xd2, 0xfb, 0x30, 0x3e, 0xd0, 0xfd, 0x4f, 0x00, 0x00, 0x00, 0xff, 0xff, + 0x04, 0x64, 0x9c, 0x2f, 0xb3, 0x01, 0x00, 0x00, +} diff --git a/credentials/alts/core/proto/altscontext.proto b/credentials/alts/core/proto/altscontext.proto new file mode 100644 index 00000000..d195b37e --- /dev/null +++ b/credentials/alts/core/proto/altscontext.proto @@ -0,0 +1,41 @@ +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +import "transport_security_common.proto"; + +package grpc.gcp; + +option java_package = "io.grpc.alts"; + +message AltsContext { + // The application protocol negotiated for this connection. + string application_protocol = 1; + + // The record protocol negotiated for this connection. + string record_protocol = 2; + + // The security level of the created secure channel. + SecurityLevel security_level = 3; + + // The peer service account. + string peer_service_account = 4; + + // The local service account. + string local_service_account = 5; + + // The RPC protocol versions supported by the peer. + RpcProtocolVersions peer_rpc_versions = 6; +} diff --git a/credentials/alts/core/proto/handshaker.pb.go b/credentials/alts/core/proto/handshaker.pb.go new file mode 100644 index 00000000..8a2090ab --- /dev/null +++ b/credentials/alts/core/proto/handshaker.pb.go @@ -0,0 +1,933 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: handshaker.proto + +package grpc_gcp + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +type HandshakeProtocol int32 + +const ( + // Default value. + HandshakeProtocol_HANDSHAKE_PROTOCOL_UNSPECIFIED HandshakeProtocol = 0 + // TLS handshake protocol. + HandshakeProtocol_TLS HandshakeProtocol = 1 + // Application Layer Transport Security handshake protocol. + HandshakeProtocol_ALTS HandshakeProtocol = 2 +) + +var HandshakeProtocol_name = map[int32]string{ + 0: "HANDSHAKE_PROTOCOL_UNSPECIFIED", + 1: "TLS", + 2: "ALTS", +} +var HandshakeProtocol_value = map[string]int32{ + "HANDSHAKE_PROTOCOL_UNSPECIFIED": 0, + "TLS": 1, + "ALTS": 2, +} + +func (x HandshakeProtocol) String() string { + return proto.EnumName(HandshakeProtocol_name, int32(x)) +} +func (HandshakeProtocol) EnumDescriptor() ([]byte, []int) { return fileDescriptor1, []int{0} } + +type NetworkProtocol int32 + +const ( + NetworkProtocol_NETWORK_PROTOCOL_UNSPECIFIED NetworkProtocol = 0 + NetworkProtocol_TCP NetworkProtocol = 1 + NetworkProtocol_UDP NetworkProtocol = 2 +) + +var NetworkProtocol_name = map[int32]string{ + 0: "NETWORK_PROTOCOL_UNSPECIFIED", + 1: "TCP", + 2: "UDP", +} +var NetworkProtocol_value = map[string]int32{ + "NETWORK_PROTOCOL_UNSPECIFIED": 0, + "TCP": 1, + "UDP": 2, +} + +func (x NetworkProtocol) String() string { + return proto.EnumName(NetworkProtocol_name, int32(x)) +} +func (NetworkProtocol) EnumDescriptor() ([]byte, []int) { return fileDescriptor1, []int{1} } + +type Endpoint struct { + // IP address. It should contain an IPv4 or IPv6 string literal, e.g. + // "192.168.0.1" or "2001:db8::1". + IpAddress string `protobuf:"bytes,1,opt,name=ip_address,json=ipAddress" json:"ip_address,omitempty"` + // Port number. + Port int32 `protobuf:"varint,2,opt,name=port" json:"port,omitempty"` + // Network protocol (e.g., TCP, UDP) associated with this endpoint. + Protocol NetworkProtocol `protobuf:"varint,3,opt,name=protocol,enum=grpc.gcp.NetworkProtocol" json:"protocol,omitempty"` +} + +func (m *Endpoint) Reset() { *m = Endpoint{} } +func (m *Endpoint) String() string { return proto.CompactTextString(m) } +func (*Endpoint) ProtoMessage() {} +func (*Endpoint) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{0} } + +func (m *Endpoint) GetIpAddress() string { + if m != nil { + return m.IpAddress + } + return "" +} + +func (m *Endpoint) GetPort() int32 { + if m != nil { + return m.Port + } + return 0 +} + +func (m *Endpoint) GetProtocol() NetworkProtocol { + if m != nil { + return m.Protocol + } + return NetworkProtocol_NETWORK_PROTOCOL_UNSPECIFIED +} + +type Identity struct { + // Types that are valid to be assigned to IdentityOneof: + // *Identity_ServiceAccount + // *Identity_Hostname + IdentityOneof isIdentity_IdentityOneof `protobuf_oneof:"identity_oneof"` +} + +func (m *Identity) Reset() { *m = Identity{} } +func (m *Identity) String() string { return proto.CompactTextString(m) } +func (*Identity) ProtoMessage() {} +func (*Identity) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{1} } + +type isIdentity_IdentityOneof interface { + isIdentity_IdentityOneof() +} + +type Identity_ServiceAccount struct { + ServiceAccount string `protobuf:"bytes,1,opt,name=service_account,json=serviceAccount,oneof"` +} +type Identity_Hostname struct { + Hostname string `protobuf:"bytes,2,opt,name=hostname,oneof"` +} + +func (*Identity_ServiceAccount) isIdentity_IdentityOneof() {} +func (*Identity_Hostname) isIdentity_IdentityOneof() {} + +func (m *Identity) GetIdentityOneof() isIdentity_IdentityOneof { + if m != nil { + return m.IdentityOneof + } + return nil +} + +func (m *Identity) GetServiceAccount() string { + if x, ok := m.GetIdentityOneof().(*Identity_ServiceAccount); ok { + return x.ServiceAccount + } + return "" +} + +func (m *Identity) GetHostname() string { + if x, ok := m.GetIdentityOneof().(*Identity_Hostname); ok { + return x.Hostname + } + return "" +} + +// XXX_OneofFuncs is for the internal use of the proto package. +func (*Identity) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) { + return _Identity_OneofMarshaler, _Identity_OneofUnmarshaler, _Identity_OneofSizer, []interface{}{ + (*Identity_ServiceAccount)(nil), + (*Identity_Hostname)(nil), + } +} + +func _Identity_OneofMarshaler(msg proto.Message, b *proto.Buffer) error { + m := msg.(*Identity) + // identity_oneof + switch x := m.IdentityOneof.(type) { + case *Identity_ServiceAccount: + b.EncodeVarint(1<<3 | proto.WireBytes) + b.EncodeStringBytes(x.ServiceAccount) + case *Identity_Hostname: + b.EncodeVarint(2<<3 | proto.WireBytes) + b.EncodeStringBytes(x.Hostname) + case nil: + default: + return fmt.Errorf("Identity.IdentityOneof has unexpected type %T", x) + } + return nil +} + +func _Identity_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) { + m := msg.(*Identity) + switch tag { + case 1: // identity_oneof.service_account + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + x, err := b.DecodeStringBytes() + m.IdentityOneof = &Identity_ServiceAccount{x} + return true, err + case 2: // identity_oneof.hostname + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + x, err := b.DecodeStringBytes() + m.IdentityOneof = &Identity_Hostname{x} + return true, err + default: + return false, nil + } +} + +func _Identity_OneofSizer(msg proto.Message) (n int) { + m := msg.(*Identity) + // identity_oneof + switch x := m.IdentityOneof.(type) { + case *Identity_ServiceAccount: + n += proto.SizeVarint(1<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(len(x.ServiceAccount))) + n += len(x.ServiceAccount) + case *Identity_Hostname: + n += proto.SizeVarint(2<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(len(x.Hostname))) + n += len(x.Hostname) + case nil: + default: + panic(fmt.Sprintf("proto: unexpected type %T in oneof", x)) + } + return n +} + +type StartClientHandshakeReq struct { + // Handshake security protocol requested by the client. + HandshakeSecurityProtocol HandshakeProtocol `protobuf:"varint,1,opt,name=handshake_security_protocol,json=handshakeSecurityProtocol,enum=grpc.gcp.HandshakeProtocol" json:"handshake_security_protocol,omitempty"` + // The application protocols supported by the client, e.g., "h2" (for http2), + // "grpc". + ApplicationProtocols []string `protobuf:"bytes,2,rep,name=application_protocols,json=applicationProtocols" json:"application_protocols,omitempty"` + // The record protocols supported by the client, e.g., + // "ALTSRP_GCM_AES128". + RecordProtocols []string `protobuf:"bytes,3,rep,name=record_protocols,json=recordProtocols" json:"record_protocols,omitempty"` + // (Optional) Describes which server identities are acceptable by the client. + // If target identities are provided and none of them matches the peer + // identity of the server, handshake will fail. + TargetIdentities []*Identity `protobuf:"bytes,4,rep,name=target_identities,json=targetIdentities" json:"target_identities,omitempty"` + // (Optional) Application may specify a local identity. Otherwise, the + // handshaker chooses a default local identity. + LocalIdentity *Identity `protobuf:"bytes,5,opt,name=local_identity,json=localIdentity" json:"local_identity,omitempty"` + // (Optional) Local endpoint information of the connection to the server, + // such as local IP address, port number, and network protocol. + LocalEndpoint *Endpoint `protobuf:"bytes,6,opt,name=local_endpoint,json=localEndpoint" json:"local_endpoint,omitempty"` + // (Optional) Endpoint information of the remote server, such as IP address, + // port number, and network protocol. + RemoteEndpoint *Endpoint `protobuf:"bytes,7,opt,name=remote_endpoint,json=remoteEndpoint" json:"remote_endpoint,omitempty"` + // (Optional) If target name is provided, a secure naming check is performed + // to verify that the peer authenticated identity is indeed authorized to run + // the target name. + TargetName string `protobuf:"bytes,8,opt,name=target_name,json=targetName" json:"target_name,omitempty"` + // (Optional) RPC protocol versions supported by the client. + RpcVersions *RpcProtocolVersions `protobuf:"bytes,9,opt,name=rpc_versions,json=rpcVersions" json:"rpc_versions,omitempty"` +} + +func (m *StartClientHandshakeReq) Reset() { *m = StartClientHandshakeReq{} } +func (m *StartClientHandshakeReq) String() string { return proto.CompactTextString(m) } +func (*StartClientHandshakeReq) ProtoMessage() {} +func (*StartClientHandshakeReq) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{2} } + +func (m *StartClientHandshakeReq) GetHandshakeSecurityProtocol() HandshakeProtocol { + if m != nil { + return m.HandshakeSecurityProtocol + } + return HandshakeProtocol_HANDSHAKE_PROTOCOL_UNSPECIFIED +} + +func (m *StartClientHandshakeReq) GetApplicationProtocols() []string { + if m != nil { + return m.ApplicationProtocols + } + return nil +} + +func (m *StartClientHandshakeReq) GetRecordProtocols() []string { + if m != nil { + return m.RecordProtocols + } + return nil +} + +func (m *StartClientHandshakeReq) GetTargetIdentities() []*Identity { + if m != nil { + return m.TargetIdentities + } + return nil +} + +func (m *StartClientHandshakeReq) GetLocalIdentity() *Identity { + if m != nil { + return m.LocalIdentity + } + return nil +} + +func (m *StartClientHandshakeReq) GetLocalEndpoint() *Endpoint { + if m != nil { + return m.LocalEndpoint + } + return nil +} + +func (m *StartClientHandshakeReq) GetRemoteEndpoint() *Endpoint { + if m != nil { + return m.RemoteEndpoint + } + return nil +} + +func (m *StartClientHandshakeReq) GetTargetName() string { + if m != nil { + return m.TargetName + } + return "" +} + +func (m *StartClientHandshakeReq) GetRpcVersions() *RpcProtocolVersions { + if m != nil { + return m.RpcVersions + } + return nil +} + +type ServerHandshakeParameters struct { + // The record protocols supported by the server, e.g., + // "ALTSRP_GCM_AES128". + RecordProtocols []string `protobuf:"bytes,1,rep,name=record_protocols,json=recordProtocols" json:"record_protocols,omitempty"` + // (Optional) A list of local identities supported by the server, if + // specified. Otherwise, the handshaker chooses a default local identity. + LocalIdentities []*Identity `protobuf:"bytes,2,rep,name=local_identities,json=localIdentities" json:"local_identities,omitempty"` +} + +func (m *ServerHandshakeParameters) Reset() { *m = ServerHandshakeParameters{} } +func (m *ServerHandshakeParameters) String() string { return proto.CompactTextString(m) } +func (*ServerHandshakeParameters) ProtoMessage() {} +func (*ServerHandshakeParameters) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{3} } + +func (m *ServerHandshakeParameters) GetRecordProtocols() []string { + if m != nil { + return m.RecordProtocols + } + return nil +} + +func (m *ServerHandshakeParameters) GetLocalIdentities() []*Identity { + if m != nil { + return m.LocalIdentities + } + return nil +} + +type StartServerHandshakeReq struct { + // The application protocols supported by the server, e.g., "h2" (for http2), + // "grpc". + ApplicationProtocols []string `protobuf:"bytes,1,rep,name=application_protocols,json=applicationProtocols" json:"application_protocols,omitempty"` + // Handshake parameters (record protocols and local identities supported by + // the server) mapped by the handshake protocol. Each handshake security + // protocol (e.g., TLS or ALTS) has its own set of record protocols and local + // identities. Since protobuf does not support enum as key to the map, the key + // to handshake_parameters is the integer value of HandshakeProtocol enum. + HandshakeParameters map[int32]*ServerHandshakeParameters `protobuf:"bytes,2,rep,name=handshake_parameters,json=handshakeParameters" json:"handshake_parameters,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // Bytes in out_frames returned from the peer's HandshakerResp. It is possible + // that the peer's out_frames are split into multiple HandshakReq messages. + InBytes []byte `protobuf:"bytes,3,opt,name=in_bytes,json=inBytes,proto3" json:"in_bytes,omitempty"` + // (Optional) Local endpoint information of the connection to the client, + // such as local IP address, port number, and network protocol. + LocalEndpoint *Endpoint `protobuf:"bytes,4,opt,name=local_endpoint,json=localEndpoint" json:"local_endpoint,omitempty"` + // (Optional) Endpoint information of the remote client, such as IP address, + // port number, and network protocol. + RemoteEndpoint *Endpoint `protobuf:"bytes,5,opt,name=remote_endpoint,json=remoteEndpoint" json:"remote_endpoint,omitempty"` + // (Optional) RPC protocol versions supported by the server. + RpcVersions *RpcProtocolVersions `protobuf:"bytes,6,opt,name=rpc_versions,json=rpcVersions" json:"rpc_versions,omitempty"` +} + +func (m *StartServerHandshakeReq) Reset() { *m = StartServerHandshakeReq{} } +func (m *StartServerHandshakeReq) String() string { return proto.CompactTextString(m) } +func (*StartServerHandshakeReq) ProtoMessage() {} +func (*StartServerHandshakeReq) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{4} } + +func (m *StartServerHandshakeReq) GetApplicationProtocols() []string { + if m != nil { + return m.ApplicationProtocols + } + return nil +} + +func (m *StartServerHandshakeReq) GetHandshakeParameters() map[int32]*ServerHandshakeParameters { + if m != nil { + return m.HandshakeParameters + } + return nil +} + +func (m *StartServerHandshakeReq) GetInBytes() []byte { + if m != nil { + return m.InBytes + } + return nil +} + +func (m *StartServerHandshakeReq) GetLocalEndpoint() *Endpoint { + if m != nil { + return m.LocalEndpoint + } + return nil +} + +func (m *StartServerHandshakeReq) GetRemoteEndpoint() *Endpoint { + if m != nil { + return m.RemoteEndpoint + } + return nil +} + +func (m *StartServerHandshakeReq) GetRpcVersions() *RpcProtocolVersions { + if m != nil { + return m.RpcVersions + } + return nil +} + +type NextHandshakeMessageReq struct { + // Bytes in out_frames returned from the peer's HandshakerResp. It is possible + // that the peer's out_frames are split into multiple NextHandshakerMessageReq + // messages. + InBytes []byte `protobuf:"bytes,1,opt,name=in_bytes,json=inBytes,proto3" json:"in_bytes,omitempty"` +} + +func (m *NextHandshakeMessageReq) Reset() { *m = NextHandshakeMessageReq{} } +func (m *NextHandshakeMessageReq) String() string { return proto.CompactTextString(m) } +func (*NextHandshakeMessageReq) ProtoMessage() {} +func (*NextHandshakeMessageReq) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{5} } + +func (m *NextHandshakeMessageReq) GetInBytes() []byte { + if m != nil { + return m.InBytes + } + return nil +} + +type HandshakerReq struct { + // Types that are valid to be assigned to ReqOneof: + // *HandshakerReq_ClientStart + // *HandshakerReq_ServerStart + // *HandshakerReq_Next + ReqOneof isHandshakerReq_ReqOneof `protobuf_oneof:"req_oneof"` +} + +func (m *HandshakerReq) Reset() { *m = HandshakerReq{} } +func (m *HandshakerReq) String() string { return proto.CompactTextString(m) } +func (*HandshakerReq) ProtoMessage() {} +func (*HandshakerReq) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{6} } + +type isHandshakerReq_ReqOneof interface { + isHandshakerReq_ReqOneof() +} + +type HandshakerReq_ClientStart struct { + ClientStart *StartClientHandshakeReq `protobuf:"bytes,1,opt,name=client_start,json=clientStart,oneof"` +} +type HandshakerReq_ServerStart struct { + ServerStart *StartServerHandshakeReq `protobuf:"bytes,2,opt,name=server_start,json=serverStart,oneof"` +} +type HandshakerReq_Next struct { + Next *NextHandshakeMessageReq `protobuf:"bytes,3,opt,name=next,oneof"` +} + +func (*HandshakerReq_ClientStart) isHandshakerReq_ReqOneof() {} +func (*HandshakerReq_ServerStart) isHandshakerReq_ReqOneof() {} +func (*HandshakerReq_Next) isHandshakerReq_ReqOneof() {} + +func (m *HandshakerReq) GetReqOneof() isHandshakerReq_ReqOneof { + if m != nil { + return m.ReqOneof + } + return nil +} + +func (m *HandshakerReq) GetClientStart() *StartClientHandshakeReq { + if x, ok := m.GetReqOneof().(*HandshakerReq_ClientStart); ok { + return x.ClientStart + } + return nil +} + +func (m *HandshakerReq) GetServerStart() *StartServerHandshakeReq { + if x, ok := m.GetReqOneof().(*HandshakerReq_ServerStart); ok { + return x.ServerStart + } + return nil +} + +func (m *HandshakerReq) GetNext() *NextHandshakeMessageReq { + if x, ok := m.GetReqOneof().(*HandshakerReq_Next); ok { + return x.Next + } + return nil +} + +// XXX_OneofFuncs is for the internal use of the proto package. +func (*HandshakerReq) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) { + return _HandshakerReq_OneofMarshaler, _HandshakerReq_OneofUnmarshaler, _HandshakerReq_OneofSizer, []interface{}{ + (*HandshakerReq_ClientStart)(nil), + (*HandshakerReq_ServerStart)(nil), + (*HandshakerReq_Next)(nil), + } +} + +func _HandshakerReq_OneofMarshaler(msg proto.Message, b *proto.Buffer) error { + m := msg.(*HandshakerReq) + // req_oneof + switch x := m.ReqOneof.(type) { + case *HandshakerReq_ClientStart: + b.EncodeVarint(1<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.ClientStart); err != nil { + return err + } + case *HandshakerReq_ServerStart: + b.EncodeVarint(2<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.ServerStart); err != nil { + return err + } + case *HandshakerReq_Next: + b.EncodeVarint(3<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.Next); err != nil { + return err + } + case nil: + default: + return fmt.Errorf("HandshakerReq.ReqOneof has unexpected type %T", x) + } + return nil +} + +func _HandshakerReq_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) { + m := msg.(*HandshakerReq) + switch tag { + case 1: // req_oneof.client_start + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(StartClientHandshakeReq) + err := b.DecodeMessage(msg) + m.ReqOneof = &HandshakerReq_ClientStart{msg} + return true, err + case 2: // req_oneof.server_start + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(StartServerHandshakeReq) + err := b.DecodeMessage(msg) + m.ReqOneof = &HandshakerReq_ServerStart{msg} + return true, err + case 3: // req_oneof.next + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(NextHandshakeMessageReq) + err := b.DecodeMessage(msg) + m.ReqOneof = &HandshakerReq_Next{msg} + return true, err + default: + return false, nil + } +} + +func _HandshakerReq_OneofSizer(msg proto.Message) (n int) { + m := msg.(*HandshakerReq) + // req_oneof + switch x := m.ReqOneof.(type) { + case *HandshakerReq_ClientStart: + s := proto.Size(x.ClientStart) + n += proto.SizeVarint(1<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case *HandshakerReq_ServerStart: + s := proto.Size(x.ServerStart) + n += proto.SizeVarint(2<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case *HandshakerReq_Next: + s := proto.Size(x.Next) + n += proto.SizeVarint(3<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case nil: + default: + panic(fmt.Sprintf("proto: unexpected type %T in oneof", x)) + } + return n +} + +type HandshakerResult struct { + // The application protocol negotiated for this connection. + ApplicationProtocol string `protobuf:"bytes,1,opt,name=application_protocol,json=applicationProtocol" json:"application_protocol,omitempty"` + // The record protocol negotiated for this connection. + RecordProtocol string `protobuf:"bytes,2,opt,name=record_protocol,json=recordProtocol" json:"record_protocol,omitempty"` + // Cryptographic key data. The key data may be more than the key length + // required for the record protocol, thus the client of the handshaker + // service needs to truncate the key data into the right key length. + KeyData []byte `protobuf:"bytes,3,opt,name=key_data,json=keyData,proto3" json:"key_data,omitempty"` + // The authenticated identity of the peer. + PeerIdentity *Identity `protobuf:"bytes,4,opt,name=peer_identity,json=peerIdentity" json:"peer_identity,omitempty"` + // The local identity used in the handshake. + LocalIdentity *Identity `protobuf:"bytes,5,opt,name=local_identity,json=localIdentity" json:"local_identity,omitempty"` + // Indicate whether the handshaker service client should keep the channel + // between the handshaker service open, e.g., in order to handle + // post-handshake messages in the future. + KeepChannelOpen bool `protobuf:"varint,6,opt,name=keep_channel_open,json=keepChannelOpen" json:"keep_channel_open,omitempty"` + // The RPC protocol versions supported by the peer. + PeerRpcVersions *RpcProtocolVersions `protobuf:"bytes,7,opt,name=peer_rpc_versions,json=peerRpcVersions" json:"peer_rpc_versions,omitempty"` +} + +func (m *HandshakerResult) Reset() { *m = HandshakerResult{} } +func (m *HandshakerResult) String() string { return proto.CompactTextString(m) } +func (*HandshakerResult) ProtoMessage() {} +func (*HandshakerResult) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{7} } + +func (m *HandshakerResult) GetApplicationProtocol() string { + if m != nil { + return m.ApplicationProtocol + } + return "" +} + +func (m *HandshakerResult) GetRecordProtocol() string { + if m != nil { + return m.RecordProtocol + } + return "" +} + +func (m *HandshakerResult) GetKeyData() []byte { + if m != nil { + return m.KeyData + } + return nil +} + +func (m *HandshakerResult) GetPeerIdentity() *Identity { + if m != nil { + return m.PeerIdentity + } + return nil +} + +func (m *HandshakerResult) GetLocalIdentity() *Identity { + if m != nil { + return m.LocalIdentity + } + return nil +} + +func (m *HandshakerResult) GetKeepChannelOpen() bool { + if m != nil { + return m.KeepChannelOpen + } + return false +} + +func (m *HandshakerResult) GetPeerRpcVersions() *RpcProtocolVersions { + if m != nil { + return m.PeerRpcVersions + } + return nil +} + +type HandshakerStatus struct { + // The status code. This could be the gRPC status code. + Code uint32 `protobuf:"varint,1,opt,name=code" json:"code,omitempty"` + // The status details. + Details string `protobuf:"bytes,2,opt,name=details" json:"details,omitempty"` +} + +func (m *HandshakerStatus) Reset() { *m = HandshakerStatus{} } +func (m *HandshakerStatus) String() string { return proto.CompactTextString(m) } +func (*HandshakerStatus) ProtoMessage() {} +func (*HandshakerStatus) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{8} } + +func (m *HandshakerStatus) GetCode() uint32 { + if m != nil { + return m.Code + } + return 0 +} + +func (m *HandshakerStatus) GetDetails() string { + if m != nil { + return m.Details + } + return "" +} + +type HandshakerResp struct { + // Frames to be given to the peer for the NextHandshakeMessageReq. May be + // empty if no out_frames have to be sent to the peer or if in_bytes in the + // HandshakerReq are incomplete. All the non-empty out frames must be sent to + // the peer even if the handshaker status is not OK as these frames may + // contain the alert frames. + OutFrames []byte `protobuf:"bytes,1,opt,name=out_frames,json=outFrames,proto3" json:"out_frames,omitempty"` + // Number of bytes in the in_bytes consumed by the handshaker. It is possible + // that part of in_bytes in HandshakerReq was unrelated to the handshake + // process. + BytesConsumed uint32 `protobuf:"varint,2,opt,name=bytes_consumed,json=bytesConsumed" json:"bytes_consumed,omitempty"` + // This is set iff the handshake was successful. out_frames may still be set + // to frames that needs to be forwarded to the peer. + Result *HandshakerResult `protobuf:"bytes,3,opt,name=result" json:"result,omitempty"` + // Status of the handshaker. + Status *HandshakerStatus `protobuf:"bytes,4,opt,name=status" json:"status,omitempty"` +} + +func (m *HandshakerResp) Reset() { *m = HandshakerResp{} } +func (m *HandshakerResp) String() string { return proto.CompactTextString(m) } +func (*HandshakerResp) ProtoMessage() {} +func (*HandshakerResp) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{9} } + +func (m *HandshakerResp) GetOutFrames() []byte { + if m != nil { + return m.OutFrames + } + return nil +} + +func (m *HandshakerResp) GetBytesConsumed() uint32 { + if m != nil { + return m.BytesConsumed + } + return 0 +} + +func (m *HandshakerResp) GetResult() *HandshakerResult { + if m != nil { + return m.Result + } + return nil +} + +func (m *HandshakerResp) GetStatus() *HandshakerStatus { + if m != nil { + return m.Status + } + return nil +} + +func init() { + proto.RegisterType((*Endpoint)(nil), "grpc.gcp.Endpoint") + proto.RegisterType((*Identity)(nil), "grpc.gcp.Identity") + proto.RegisterType((*StartClientHandshakeReq)(nil), "grpc.gcp.StartClientHandshakeReq") + proto.RegisterType((*ServerHandshakeParameters)(nil), "grpc.gcp.ServerHandshakeParameters") + proto.RegisterType((*StartServerHandshakeReq)(nil), "grpc.gcp.StartServerHandshakeReq") + proto.RegisterType((*NextHandshakeMessageReq)(nil), "grpc.gcp.NextHandshakeMessageReq") + proto.RegisterType((*HandshakerReq)(nil), "grpc.gcp.HandshakerReq") + proto.RegisterType((*HandshakerResult)(nil), "grpc.gcp.HandshakerResult") + proto.RegisterType((*HandshakerStatus)(nil), "grpc.gcp.HandshakerStatus") + proto.RegisterType((*HandshakerResp)(nil), "grpc.gcp.HandshakerResp") + proto.RegisterEnum("grpc.gcp.HandshakeProtocol", HandshakeProtocol_name, HandshakeProtocol_value) + proto.RegisterEnum("grpc.gcp.NetworkProtocol", NetworkProtocol_name, NetworkProtocol_value) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for HandshakerService service + +type HandshakerServiceClient interface { + // Accepts a stream of handshaker request, returning a stream of handshaker + // response. + DoHandshake(ctx context.Context, opts ...grpc.CallOption) (HandshakerService_DoHandshakeClient, error) +} + +type handshakerServiceClient struct { + cc *grpc.ClientConn +} + +func NewHandshakerServiceClient(cc *grpc.ClientConn) HandshakerServiceClient { + return &handshakerServiceClient{cc} +} + +func (c *handshakerServiceClient) DoHandshake(ctx context.Context, opts ...grpc.CallOption) (HandshakerService_DoHandshakeClient, error) { + stream, err := grpc.NewClientStream(ctx, &_HandshakerService_serviceDesc.Streams[0], c.cc, "/grpc.gcp.HandshakerService/DoHandshake", opts...) + if err != nil { + return nil, err + } + x := &handshakerServiceDoHandshakeClient{stream} + return x, nil +} + +type HandshakerService_DoHandshakeClient interface { + Send(*HandshakerReq) error + Recv() (*HandshakerResp, error) + grpc.ClientStream +} + +type handshakerServiceDoHandshakeClient struct { + grpc.ClientStream +} + +func (x *handshakerServiceDoHandshakeClient) Send(m *HandshakerReq) error { + return x.ClientStream.SendMsg(m) +} + +func (x *handshakerServiceDoHandshakeClient) Recv() (*HandshakerResp, error) { + m := new(HandshakerResp) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for HandshakerService service + +type HandshakerServiceServer interface { + // Accepts a stream of handshaker request, returning a stream of handshaker + // response. + DoHandshake(HandshakerService_DoHandshakeServer) error +} + +func RegisterHandshakerServiceServer(s *grpc.Server, srv HandshakerServiceServer) { + s.RegisterService(&_HandshakerService_serviceDesc, srv) +} + +func _HandshakerService_DoHandshake_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(HandshakerServiceServer).DoHandshake(&handshakerServiceDoHandshakeServer{stream}) +} + +type HandshakerService_DoHandshakeServer interface { + Send(*HandshakerResp) error + Recv() (*HandshakerReq, error) + grpc.ServerStream +} + +type handshakerServiceDoHandshakeServer struct { + grpc.ServerStream +} + +func (x *handshakerServiceDoHandshakeServer) Send(m *HandshakerResp) error { + return x.ServerStream.SendMsg(m) +} + +func (x *handshakerServiceDoHandshakeServer) Recv() (*HandshakerReq, error) { + m := new(HandshakerReq) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _HandshakerService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.gcp.HandshakerService", + HandlerType: (*HandshakerServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "DoHandshake", + Handler: _HandshakerService_DoHandshake_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "handshaker.proto", +} + +func init() { proto.RegisterFile("handshaker.proto", fileDescriptor1) } + +var fileDescriptor1 = []byte{ + // 1066 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x56, 0xdd, 0x6e, 0x1a, 0x47, + 0x14, 0xf6, 0x02, 0xb6, 0xe1, 0x60, 0x60, 0x3d, 0x71, 0x65, 0xec, 0x24, 0x0d, 0xa5, 0xaa, 0x4a, + 0x7c, 0x61, 0xb5, 0xa4, 0x55, 0x9a, 0x54, 0x55, 0x63, 0x63, 0x2c, 0xdc, 0xb8, 0xd8, 0x1a, 0x9c, + 0xf6, 0x22, 0x17, 0xab, 0xc9, 0x72, 0x62, 0xaf, 0x80, 0x99, 0xf5, 0xcc, 0xe0, 0x86, 0x07, 0xe8, + 0xe3, 0xf4, 0x15, 0xfa, 0x36, 0x7d, 0x83, 0xde, 0xb7, 0xda, 0xd9, 0x3f, 0x8c, 0x97, 0x28, 0x51, + 0xee, 0x76, 0xcf, 0x7e, 0xdf, 0xd9, 0x39, 0xdf, 0xf9, 0xe6, 0xcc, 0x80, 0x7d, 0xc5, 0xf8, 0x50, + 0x5d, 0xb1, 0x11, 0xca, 0x7d, 0x5f, 0x0a, 0x2d, 0x48, 0xf1, 0x52, 0xfa, 0xee, 0xfe, 0xa5, 0xeb, + 0xef, 0x3e, 0xd2, 0x92, 0x71, 0xe5, 0x0b, 0xa9, 0x1d, 0x85, 0xee, 0x54, 0x7a, 0x7a, 0xe6, 0xb8, + 0x62, 0x32, 0x11, 0x3c, 0x84, 0x36, 0x35, 0x14, 0xbb, 0x7c, 0xe8, 0x0b, 0x8f, 0x6b, 0xf2, 0x10, + 0xc0, 0xf3, 0x1d, 0x36, 0x1c, 0x4a, 0x54, 0xaa, 0x6e, 0x35, 0xac, 0x56, 0x89, 0x96, 0x3c, 0xff, + 0x20, 0x0c, 0x10, 0x02, 0x85, 0x20, 0x51, 0x3d, 0xd7, 0xb0, 0x5a, 0xab, 0xd4, 0x3c, 0x93, 0xef, + 0xa1, 0x68, 0xf2, 0xb8, 0x62, 0x5c, 0xcf, 0x37, 0xac, 0x56, 0xb5, 0xbd, 0xb3, 0x1f, 0xff, 0x7c, + 0xbf, 0x8f, 0xfa, 0x0f, 0x21, 0x47, 0xe7, 0x11, 0x80, 0x26, 0xd0, 0x26, 0x42, 0xf1, 0x64, 0x88, + 0x5c, 0x7b, 0x7a, 0x46, 0x1e, 0x43, 0x4d, 0xa1, 0xbc, 0xf1, 0x5c, 0x74, 0x98, 0xeb, 0x8a, 0x29, + 0xd7, 0xe1, 0xaf, 0x7b, 0x2b, 0xb4, 0x1a, 0x7d, 0x38, 0x08, 0xe3, 0xe4, 0x01, 0x14, 0xaf, 0x84, + 0xd2, 0x9c, 0x4d, 0xd0, 0xac, 0x22, 0xc0, 0x24, 0x91, 0x43, 0x1b, 0xaa, 0x5e, 0x94, 0xd4, 0x11, + 0x1c, 0xc5, 0xdb, 0xe6, 0x5f, 0x05, 0xd8, 0x1e, 0x68, 0x26, 0x75, 0x67, 0xec, 0x21, 0xd7, 0xbd, + 0x58, 0x27, 0x8a, 0xd7, 0xe4, 0x35, 0xdc, 0x4f, 0x74, 0x4b, 0xb5, 0x49, 0x8a, 0xb1, 0x4c, 0x31, + 0xf7, 0xd3, 0x62, 0x12, 0x72, 0x52, 0xce, 0x4e, 0xc2, 0x1f, 0x44, 0xf4, 0xf8, 0x13, 0x79, 0x02, + 0x9f, 0x31, 0xdf, 0x1f, 0x7b, 0x2e, 0xd3, 0x9e, 0xe0, 0x49, 0x56, 0x55, 0xcf, 0x35, 0xf2, 0xad, + 0x12, 0xdd, 0x9a, 0xfb, 0x18, 0x73, 0x14, 0x79, 0x0c, 0xb6, 0x44, 0x57, 0xc8, 0xe1, 0x1c, 0x3e, + 0x6f, 0xf0, 0xb5, 0x30, 0x9e, 0x42, 0x7f, 0x86, 0x4d, 0xcd, 0xe4, 0x25, 0x6a, 0x27, 0xaa, 0xd8, + 0x43, 0x55, 0x2f, 0x34, 0xf2, 0xad, 0x72, 0x9b, 0xa4, 0x4b, 0x8e, 0x25, 0xa6, 0x76, 0x08, 0x3e, + 0x49, 0xb0, 0xe4, 0x19, 0x54, 0xc7, 0xc2, 0x65, 0xe3, 0x98, 0x3f, 0xab, 0xaf, 0x36, 0xac, 0x25, + 0xec, 0x8a, 0x41, 0x26, 0xfd, 0x4a, 0xa8, 0x18, 0xf9, 0xa6, 0xbe, 0xb6, 0x48, 0x8d, 0x1d, 0x15, + 0x51, 0x13, 0x83, 0xfd, 0x08, 0x35, 0x89, 0x13, 0xa1, 0x31, 0xe5, 0xae, 0x2f, 0xe5, 0x56, 0x43, + 0x68, 0x42, 0x7e, 0x04, 0xe5, 0xa8, 0x66, 0xd3, 0xff, 0xa2, 0xb1, 0x27, 0x84, 0xa1, 0x3e, 0x9b, + 0x20, 0x79, 0x01, 0x1b, 0xd2, 0x77, 0x9d, 0x1b, 0x94, 0xca, 0x13, 0x5c, 0xd5, 0x4b, 0x26, 0xf5, + 0xc3, 0x34, 0x35, 0xf5, 0xdd, 0x58, 0xc2, 0xdf, 0x22, 0x10, 0x2d, 0x4b, 0xdf, 0x8d, 0x5f, 0x9a, + 0x7f, 0x5a, 0xb0, 0x33, 0x40, 0x79, 0x83, 0x32, 0xed, 0x36, 0x93, 0x6c, 0x82, 0x1a, 0x65, 0x76, + 0x7f, 0xac, 0xec, 0xfe, 0xfc, 0x04, 0xf6, 0x2d, 0x79, 0x83, 0xf6, 0xe4, 0x96, 0xb6, 0xa7, 0x36, + 0x2f, 0xb0, 0x87, 0xaa, 0xf9, 0x5f, 0x3e, 0xf2, 0xed, 0xc2, 0x62, 0x02, 0xdf, 0x2e, 0xb5, 0x96, + 0xf5, 0x1e, 0x6b, 0x4d, 0x60, 0x2b, 0x35, 0xbb, 0x9f, 0x94, 0x14, 0xad, 0xe9, 0x79, 0xba, 0xa6, + 0x25, 0x7f, 0xdd, 0xcf, 0xd0, 0xa3, 0xcb, 0xb5, 0x9c, 0xd1, 0x7b, 0x57, 0x19, 0x4a, 0xed, 0x40, + 0xd1, 0xe3, 0xce, 0x9b, 0x99, 0x46, 0x65, 0xa6, 0xc2, 0x06, 0x5d, 0xf7, 0xf8, 0x61, 0xf0, 0x9a, + 0xe1, 0x9e, 0xc2, 0x27, 0xb8, 0x67, 0xf5, 0x83, 0xdd, 0xb3, 0x68, 0x8e, 0xb5, 0x8f, 0x35, 0xc7, + 0xee, 0x08, 0xea, 0xcb, 0x54, 0x20, 0x36, 0xe4, 0x47, 0x38, 0x33, 0x43, 0x63, 0x95, 0x06, 0x8f, + 0xe4, 0x19, 0xac, 0xde, 0xb0, 0xf1, 0x34, 0x9c, 0x53, 0xe5, 0xf6, 0x97, 0x73, 0x12, 0x2f, 0x33, + 0x18, 0x0d, 0x19, 0xcf, 0x73, 0x3f, 0x58, 0xcd, 0xef, 0x60, 0xbb, 0x8f, 0xef, 0xd2, 0x89, 0xf5, + 0x2b, 0x2a, 0xc5, 0x2e, 0x8d, 0x01, 0xe6, 0xc5, 0xb5, 0x6e, 0x89, 0xdb, 0xfc, 0xc7, 0x82, 0x4a, + 0x42, 0x91, 0x01, 0xf8, 0x18, 0x36, 0x5c, 0x33, 0xfb, 0x1c, 0x15, 0x74, 0xd6, 0x10, 0xca, 0xed, + 0x2f, 0x16, 0x1a, 0x7e, 0x77, 0x3c, 0xf6, 0x56, 0x68, 0x39, 0x24, 0x1a, 0x40, 0x90, 0x47, 0x99, + 0x75, 0x47, 0x79, 0x72, 0x99, 0x79, 0xee, 0x1a, 0x27, 0xc8, 0x13, 0x12, 0xc3, 0x3c, 0x4f, 0xa1, + 0xc0, 0xf1, 0x9d, 0x36, 0xae, 0xb8, 0xc5, 0x5f, 0x52, 0x6d, 0x6f, 0x85, 0x1a, 0xc2, 0x61, 0x19, + 0x4a, 0x12, 0xaf, 0xa3, 0xb9, 0xfe, 0x6f, 0x0e, 0xec, 0xf9, 0x3a, 0xd5, 0x74, 0xac, 0xc9, 0xb7, + 0xb0, 0x95, 0xb5, 0x31, 0xa2, 0x73, 0xec, 0x5e, 0xc6, 0xbe, 0x20, 0x5f, 0x43, 0x6d, 0x61, 0x47, + 0x87, 0xc7, 0x4a, 0xe0, 0x9e, 0xf9, 0x0d, 0x1d, 0x68, 0x3e, 0xc2, 0x99, 0x33, 0x64, 0x9a, 0xc5, + 0x86, 0x1e, 0xe1, 0xec, 0x88, 0x69, 0x46, 0x9e, 0x42, 0xc5, 0x47, 0x94, 0xe9, 0x20, 0x2d, 0x2c, + 0x1d, 0xa4, 0x1b, 0x01, 0xf0, 0xee, 0x1c, 0xfd, 0xf8, 0x11, 0xbc, 0x07, 0x9b, 0x23, 0x44, 0xdf, + 0x71, 0xaf, 0x18, 0xe7, 0x38, 0x76, 0x84, 0x8f, 0xdc, 0x38, 0xba, 0x48, 0x6b, 0xc1, 0x87, 0x4e, + 0x18, 0x3f, 0xf3, 0x91, 0x93, 0x13, 0xd8, 0x34, 0xeb, 0xbb, 0xe5, 0xfe, 0xf5, 0x0f, 0x71, 0x7f, + 0x2d, 0xe0, 0xd1, 0xb9, 0xf1, 0xf8, 0x62, 0x5e, 0xf5, 0x81, 0x66, 0x7a, 0x6a, 0x2e, 0x05, 0xae, + 0x18, 0xa2, 0x51, 0xb9, 0x42, 0xcd, 0x33, 0xa9, 0xc3, 0xfa, 0x10, 0x35, 0xf3, 0xcc, 0x79, 0x17, + 0xc8, 0x19, 0xbf, 0x36, 0xff, 0xb6, 0xa0, 0x7a, 0xab, 0x71, 0x7e, 0x70, 0xe9, 0x10, 0x53, 0xed, + 0xbc, 0x0d, 0x76, 0x41, 0x6c, 0xe8, 0x92, 0x98, 0xea, 0x63, 0x13, 0x20, 0x5f, 0x41, 0xd5, 0x58, + 0xdd, 0x71, 0x05, 0x57, 0xd3, 0x09, 0x0e, 0x4d, 0xca, 0x0a, 0xad, 0x98, 0x68, 0x27, 0x0a, 0x92, + 0x36, 0xac, 0x49, 0x63, 0x83, 0xc8, 0x59, 0xbb, 0x19, 0x07, 0x77, 0x64, 0x14, 0x1a, 0x21, 0x03, + 0x8e, 0x32, 0x45, 0x44, 0x2d, 0xcb, 0xe4, 0x84, 0x65, 0xd2, 0x08, 0xb9, 0xf7, 0x0b, 0x6c, 0xde, + 0xb9, 0x08, 0x90, 0x26, 0x7c, 0xde, 0x3b, 0xe8, 0x1f, 0x0d, 0x7a, 0x07, 0x2f, 0xbb, 0xce, 0x39, + 0x3d, 0xbb, 0x38, 0xeb, 0x9c, 0x9d, 0x3a, 0xaf, 0xfa, 0x83, 0xf3, 0x6e, 0xe7, 0xe4, 0xf8, 0xa4, + 0x7b, 0x64, 0xaf, 0x90, 0x75, 0xc8, 0x5f, 0x9c, 0x0e, 0x6c, 0x8b, 0x14, 0xa1, 0x70, 0x70, 0x7a, + 0x31, 0xb0, 0x73, 0x7b, 0x5d, 0xa8, 0x2d, 0xdc, 0x90, 0x48, 0x03, 0x1e, 0xf4, 0xbb, 0x17, 0xbf, + 0x9f, 0xd1, 0x97, 0xef, 0xcb, 0xd3, 0x39, 0xb7, 0xad, 0xe0, 0xe1, 0xd5, 0xd1, 0xb9, 0x9d, 0x6b, + 0xbf, 0x9e, 0x5b, 0x92, 0x1c, 0x84, 0x17, 0x26, 0x72, 0x0c, 0xe5, 0x23, 0x91, 0x84, 0xc9, 0x76, + 0xb6, 0x1c, 0xd7, 0xbb, 0xf5, 0x25, 0x3a, 0xf9, 0xcd, 0x95, 0x96, 0xf5, 0x8d, 0x75, 0x58, 0x85, + 0x0d, 0x4f, 0x84, 0x18, 0x36, 0xd6, 0xea, 0xcd, 0x9a, 0xd9, 0x28, 0x4f, 0xfe, 0x0f, 0x00, 0x00, + 0xff, 0xff, 0x3f, 0xb3, 0x37, 0x22, 0x74, 0x0a, 0x00, 0x00, +} diff --git a/credentials/alts/core/proto/handshaker.proto b/credentials/alts/core/proto/handshaker.proto new file mode 100644 index 00000000..42f08c90 --- /dev/null +++ b/credentials/alts/core/proto/handshaker.proto @@ -0,0 +1,220 @@ +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +import "transport_security_common.proto"; + +package grpc.gcp; + +option java_package = "io.grpc.alts"; + +enum HandshakeProtocol { + // Default value. + HANDSHAKE_PROTOCOL_UNSPECIFIED = 0; + + // TLS handshake protocol. + TLS = 1; + + // Application Layer Transport Security handshake protocol. + ALTS = 2; +} + +enum NetworkProtocol { + NETWORK_PROTOCOL_UNSPECIFIED = 0; + TCP = 1; + UDP = 2; +} + +message Endpoint { + // IP address. It should contain an IPv4 or IPv6 string literal, e.g. + // "192.168.0.1" or "2001:db8::1". + string ip_address = 1; + + // Port number. + int32 port = 2; + + // Network protocol (e.g., TCP, UDP) associated with this endpoint. + NetworkProtocol protocol = 3; +} + +message Identity { + oneof identity_oneof { + // Service account of a connection endpoint. + string service_account = 1; + + // Hostname of a connection endpoint. + string hostname = 2; + } +} + +message StartClientHandshakeReq { + // Handshake security protocol requested by the client. + HandshakeProtocol handshake_security_protocol = 1; + + // The application protocols supported by the client, e.g., "h2" (for http2), + // "grpc". + repeated string application_protocols = 2; + + // The record protocols supported by the client, e.g., + // "ALTSRP_GCM_AES128". + repeated string record_protocols = 3; + + // (Optional) Describes which server identities are acceptable by the client. + // If target identities are provided and none of them matches the peer + // identity of the server, handshake will fail. + repeated Identity target_identities = 4; + + // (Optional) Application may specify a local identity. Otherwise, the + // handshaker chooses a default local identity. + Identity local_identity = 5; + + // (Optional) Local endpoint information of the connection to the server, + // such as local IP address, port number, and network protocol. + Endpoint local_endpoint = 6; + + // (Optional) Endpoint information of the remote server, such as IP address, + // port number, and network protocol. + Endpoint remote_endpoint = 7; + + // (Optional) If target name is provided, a secure naming check is performed + // to verify that the peer authenticated identity is indeed authorized to run + // the target name. + string target_name = 8; + + // (Optional) RPC protocol versions supported by the client. + RpcProtocolVersions rpc_versions = 9; +} + +message ServerHandshakeParameters { + // The record protocols supported by the server, e.g., + // "ALTSRP_GCM_AES128". + repeated string record_protocols = 1; + + // (Optional) A list of local identities supported by the server, if + // specified. Otherwise, the handshaker chooses a default local identity. + repeated Identity local_identities = 2; +} + +message StartServerHandshakeReq { + // The application protocols supported by the server, e.g., "h2" (for http2), + // "grpc". + repeated string application_protocols = 1; + + // Handshake parameters (record protocols and local identities supported by + // the server) mapped by the handshake protocol. Each handshake security + // protocol (e.g., TLS or ALTS) has its own set of record protocols and local + // identities. Since protobuf does not support enum as key to the map, the key + // to handshake_parameters is the integer value of HandshakeProtocol enum. + map handshake_parameters = 2; + + // Bytes in out_frames returned from the peer's HandshakerResp. It is possible + // that the peer's out_frames are split into multiple HandshakReq messages. + bytes in_bytes = 3; + + // (Optional) Local endpoint information of the connection to the client, + // such as local IP address, port number, and network protocol. + Endpoint local_endpoint = 4; + + // (Optional) Endpoint information of the remote client, such as IP address, + // port number, and network protocol. + Endpoint remote_endpoint = 5; + + // (Optional) RPC protocol versions supported by the server. + RpcProtocolVersions rpc_versions = 6; +} + +message NextHandshakeMessageReq { + // Bytes in out_frames returned from the peer's HandshakerResp. It is possible + // that the peer's out_frames are split into multiple NextHandshakerMessageReq + // messages. + bytes in_bytes = 1; +} + +message HandshakerReq { + oneof req_oneof { + // The start client handshake request message. + StartClientHandshakeReq client_start = 1; + + // The start server handshake request message. + StartServerHandshakeReq server_start = 2; + + // The next handshake request message. + NextHandshakeMessageReq next = 3; + } +} + +message HandshakerResult { + // The application protocol negotiated for this connection. + string application_protocol = 1; + + // The record protocol negotiated for this connection. + string record_protocol = 2; + + // Cryptographic key data. The key data may be more than the key length + // required for the record protocol, thus the client of the handshaker + // service needs to truncate the key data into the right key length. + bytes key_data = 3; + + // The authenticated identity of the peer. + Identity peer_identity = 4; + + // The local identity used in the handshake. + Identity local_identity = 5; + + // Indicate whether the handshaker service client should keep the channel + // between the handshaker service open, e.g., in order to handle + // post-handshake messages in the future. + bool keep_channel_open = 6; + + // The RPC protocol versions supported by the peer. + RpcProtocolVersions peer_rpc_versions = 7; +} + +message HandshakerStatus { + // The status code. This could be the gRPC status code. + uint32 code = 1; + + // The status details. + string details = 2; +} + +message HandshakerResp { + // Frames to be given to the peer for the NextHandshakeMessageReq. May be + // empty if no out_frames have to be sent to the peer or if in_bytes in the + // HandshakerReq are incomplete. All the non-empty out frames must be sent to + // the peer even if the handshaker status is not OK as these frames may + // contain the alert frames. + bytes out_frames = 1; + + // Number of bytes in the in_bytes consumed by the handshaker. It is possible + // that part of in_bytes in HandshakerReq was unrelated to the handshake + // process. + uint32 bytes_consumed = 2; + + // This is set iff the handshake was successful. out_frames may still be set + // to frames that needs to be forwarded to the peer. + HandshakerResult result = 3; + + // Status of the handshaker. + HandshakerStatus status = 4; +} + +service HandshakerService { + // Accepts a stream of handshaker request, returning a stream of handshaker + // response. + rpc DoHandshake(stream HandshakerReq) + returns (stream HandshakerResp) { + } +} diff --git a/credentials/alts/core/proto/transport_security_common.pb.go b/credentials/alts/core/proto/transport_security_common.pb.go new file mode 100644 index 00000000..16299406 --- /dev/null +++ b/credentials/alts/core/proto/transport_security_common.pb.go @@ -0,0 +1,120 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: transport_security_common.proto + +package grpc_gcp + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// The security level of the created channel. The list is sorted in increasing +// level of security. This order must always be maintained. +type SecurityLevel int32 + +const ( + SecurityLevel_SECURITY_NONE SecurityLevel = 0 + SecurityLevel_INTEGRITY_ONLY SecurityLevel = 1 + SecurityLevel_INTEGRITY_AND_PRIVACY SecurityLevel = 2 +) + +var SecurityLevel_name = map[int32]string{ + 0: "SECURITY_NONE", + 1: "INTEGRITY_ONLY", + 2: "INTEGRITY_AND_PRIVACY", +} +var SecurityLevel_value = map[string]int32{ + "SECURITY_NONE": 0, + "INTEGRITY_ONLY": 1, + "INTEGRITY_AND_PRIVACY": 2, +} + +func (x SecurityLevel) String() string { + return proto.EnumName(SecurityLevel_name, int32(x)) +} +func (SecurityLevel) EnumDescriptor() ([]byte, []int) { return fileDescriptor2, []int{0} } + +// Max and min supported RPC protocol versions. +type RpcProtocolVersions struct { + // Maximum supported RPC version. + MaxRpcVersion *RpcProtocolVersions_Version `protobuf:"bytes,1,opt,name=max_rpc_version,json=maxRpcVersion" json:"max_rpc_version,omitempty"` + // Minimum supported RPC version. + MinRpcVersion *RpcProtocolVersions_Version `protobuf:"bytes,2,opt,name=min_rpc_version,json=minRpcVersion" json:"min_rpc_version,omitempty"` +} + +func (m *RpcProtocolVersions) Reset() { *m = RpcProtocolVersions{} } +func (m *RpcProtocolVersions) String() string { return proto.CompactTextString(m) } +func (*RpcProtocolVersions) ProtoMessage() {} +func (*RpcProtocolVersions) Descriptor() ([]byte, []int) { return fileDescriptor2, []int{0} } + +func (m *RpcProtocolVersions) GetMaxRpcVersion() *RpcProtocolVersions_Version { + if m != nil { + return m.MaxRpcVersion + } + return nil +} + +func (m *RpcProtocolVersions) GetMinRpcVersion() *RpcProtocolVersions_Version { + if m != nil { + return m.MinRpcVersion + } + return nil +} + +// RPC version contains a major version and a minor version. +type RpcProtocolVersions_Version struct { + Major uint32 `protobuf:"varint,1,opt,name=major" json:"major,omitempty"` + Minor uint32 `protobuf:"varint,2,opt,name=minor" json:"minor,omitempty"` +} + +func (m *RpcProtocolVersions_Version) Reset() { *m = RpcProtocolVersions_Version{} } +func (m *RpcProtocolVersions_Version) String() string { return proto.CompactTextString(m) } +func (*RpcProtocolVersions_Version) ProtoMessage() {} +func (*RpcProtocolVersions_Version) Descriptor() ([]byte, []int) { return fileDescriptor2, []int{0, 0} } + +func (m *RpcProtocolVersions_Version) GetMajor() uint32 { + if m != nil { + return m.Major + } + return 0 +} + +func (m *RpcProtocolVersions_Version) GetMinor() uint32 { + if m != nil { + return m.Minor + } + return 0 +} + +func init() { + proto.RegisterType((*RpcProtocolVersions)(nil), "grpc.gcp.RpcProtocolVersions") + proto.RegisterType((*RpcProtocolVersions_Version)(nil), "grpc.gcp.RpcProtocolVersions.Version") + proto.RegisterEnum("grpc.gcp.SecurityLevel", SecurityLevel_name, SecurityLevel_value) +} + +func init() { proto.RegisterFile("transport_security_common.proto", fileDescriptor2) } + +var fileDescriptor2 = []byte{ + // 261 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x2f, 0x29, 0x4a, 0xcc, + 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0x89, 0x2f, 0x4e, 0x4d, 0x2e, 0x2d, 0xca, 0x2c, 0xa9, 0x8c, 0x4f, + 0xce, 0xcf, 0xcd, 0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x48, 0x2f, 0x2a, + 0x48, 0xd6, 0x4b, 0x4f, 0x2e, 0x50, 0x7a, 0xc5, 0xc8, 0x25, 0x1c, 0x54, 0x90, 0x1c, 0x00, 0x12, + 0x4e, 0xce, 0xcf, 0x09, 0x4b, 0x2d, 0x2a, 0xce, 0xcc, 0xcf, 0x2b, 0x16, 0xf2, 0xe5, 0xe2, 0xcf, + 0x4d, 0xac, 0x88, 0x2f, 0x2a, 0x48, 0x8e, 0x2f, 0x83, 0x88, 0x49, 0x30, 0x2a, 0x30, 0x6a, 0x70, + 0x1b, 0xa9, 0xea, 0xc1, 0xf4, 0xea, 0x61, 0xd1, 0xa7, 0x07, 0x65, 0x04, 0xf1, 0xe6, 0x26, 0x56, + 0x04, 0x15, 0x24, 0x43, 0xb9, 0x60, 0xe3, 0x32, 0xf3, 0x50, 0x8c, 0x63, 0x22, 0xcd, 0xb8, 0xcc, + 0x3c, 0x84, 0x71, 0x52, 0xa6, 0x5c, 0xec, 0x30, 0x93, 0x45, 0xb8, 0x58, 0x73, 0x13, 0xb3, 0xf2, + 0x8b, 0xc0, 0xce, 0xe3, 0x0d, 0x82, 0x70, 0xc0, 0xa2, 0x99, 0x79, 0xf9, 0x45, 0x60, 0x5b, 0x40, + 0xa2, 0x20, 0x8e, 0x56, 0x20, 0x17, 0x6f, 0x30, 0x34, 0x3c, 0x7c, 0x52, 0xcb, 0x52, 0x73, 0x84, + 0x04, 0xb9, 0x78, 0x83, 0x5d, 0x9d, 0x43, 0x83, 0x3c, 0x43, 0x22, 0xe3, 0xfd, 0xfc, 0xfd, 0x5c, + 0x05, 0x18, 0x84, 0x84, 0xb8, 0xf8, 0x3c, 0xfd, 0x42, 0x5c, 0xdd, 0xc1, 0x62, 0xfe, 0x7e, 0x3e, + 0x91, 0x02, 0x8c, 0x42, 0x92, 0x5c, 0xa2, 0x08, 0x31, 0x47, 0x3f, 0x97, 0xf8, 0x80, 0x20, 0xcf, + 0x30, 0x47, 0xe7, 0x48, 0x01, 0x26, 0x27, 0x3e, 0x2e, 0x9e, 0xcc, 0x7c, 0x88, 0x1f, 0x12, 0x73, + 0x4a, 0x8a, 0x93, 0xd8, 0xc0, 0x01, 0x6c, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x11, 0x06, 0x14, + 0x7a, 0x83, 0x01, 0x00, 0x00, +} diff --git a/credentials/alts/core/proto/transport_security_common.proto b/credentials/alts/core/proto/transport_security_common.proto new file mode 100644 index 00000000..41983ab9 --- /dev/null +++ b/credentials/alts/core/proto/transport_security_common.proto @@ -0,0 +1,40 @@ +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package grpc.gcp; + +option java_package = "io.grpc.alts"; + +// The security level of the created channel. The list is sorted in increasing +// level of security. This order must always be maintained. +enum SecurityLevel { + SECURITY_NONE = 0; + INTEGRITY_ONLY = 1; + INTEGRITY_AND_PRIVACY = 2; +} + +// Max and min supported RPC protocol versions. +message RpcProtocolVersions { + // RPC version contains a major version and a minor version. + message Version { + uint32 major = 1; + uint32 minor = 2; + } + // Maximum supported RPC version. + Version max_rpc_version = 1; + // Minimum supported RPC version. + Version min_rpc_version = 2; +} diff --git a/credentials/alts/core/testutil/testutil.go b/credentials/alts/core/testutil/testutil.go new file mode 100644 index 00000000..91cbd039 --- /dev/null +++ b/credentials/alts/core/testutil/testutil.go @@ -0,0 +1,125 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package testutil include useful test utilities for the handshaker. +package testutil + +import ( + "bytes" + "encoding/binary" + "io" + "net" + "sync" + + "google.golang.org/grpc/credentials/alts/core/conn" +) + +// Stats is used to collect statistics about concurrent handshake calls. +type Stats struct { + mu sync.Mutex + calls int + MaxConcurrentCalls int +} + +// Update updates the statistics by adding one call. +func (s *Stats) Update() func() { + s.mu.Lock() + s.calls++ + if s.calls > s.MaxConcurrentCalls { + s.MaxConcurrentCalls = s.calls + } + s.mu.Unlock() + + return func() { + s.mu.Lock() + s.calls-- + s.mu.Unlock() + } +} + +// Reset resets the statistics. +func (s *Stats) Reset() { + s.mu.Lock() + defer s.mu.Unlock() + s.calls = 0 + s.MaxConcurrentCalls = 0 +} + +// testConn mimics a net.Conn to the peer. +type testConn struct { + net.Conn + in *bytes.Buffer + out *bytes.Buffer +} + +// NewTestConn creates a new instance of testConn object. +func NewTestConn(in *bytes.Buffer, out *bytes.Buffer) net.Conn { + return &testConn{ + in: in, + out: out, + } +} + +// Read reads from the in buffer. +func (c *testConn) Read(b []byte) (n int, err error) { + return c.in.Read(b) +} + +// Write writes to the out buffer. +func (c *testConn) Write(b []byte) (n int, err error) { + return c.out.Write(b) +} + +// Close closes the testConn object. +func (c *testConn) Close() error { + return nil +} + +// unresponsiveTestConn mimics a net.Conn for an unresponsive peer. It is used +// for testing the PeerNotResponding case. +type unresponsiveTestConn struct { + net.Conn +} + +// NewUnresponsiveTestConn creates a new instance of unresponsiveTestConn object. +func NewUnresponsiveTestConn() net.Conn { + return &unresponsiveTestConn{} +} + +// Read reads from the in buffer. +func (c *unresponsiveTestConn) Read(b []byte) (n int, err error) { + return 0, io.EOF +} + +// Write writes to the out buffer. +func (c *unresponsiveTestConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +// Close closes the TestConn object. +func (c *unresponsiveTestConn) Close() error { + return nil +} + +// MakeFrame creates a handshake frame. +func MakeFrame(pl string) []byte { + f := make([]byte, len(pl)+conn.MsgLenFieldSize) + binary.LittleEndian.PutUint32(f, uint32(len(pl))) + copy(f[conn.MsgLenFieldSize:], []byte(pl)) + return f +} diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go new file mode 100644 index 00000000..cd5be2e6 --- /dev/null +++ b/credentials/alts/utils.go @@ -0,0 +1,117 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package alts + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "regexp" + "runtime" + "strings" + "time" +) + +const ( + linuxProductNameFile = "/sys/class/dmi/id/product_name" + windowsCheckCommand = "powershell.exe" + windowsCheckCommandArgs = "Get-WmiObject -Class Win32_BIOS" + powershellOutputFilter = "Manufacturer" + windowsManufacturerRegex = ":(.*)" + windowsCheckTimeout = 30 * time.Second +) + +type platformError string + +func (k platformError) Error() string { + return fmt.Sprintf("%v is not supported", k) +} + +var ( + // The following two variables will be reassigned in tests. + runningOS = runtime.GOOS + manufacturerReader = func() (io.Reader, error) { + switch runningOS { + case "linux": + return os.Open(linuxProductNameFile) + case "windows": + cmd := exec.Command(windowsCheckCommand, windowsCheckCommandArgs) + out, err := cmd.Output() + if err != nil { + return nil, err + } + + for _, line := range strings.Split(strings.TrimSuffix(string(out), "\n"), "\n") { + if strings.HasPrefix(line, powershellOutputFilter) { + re := regexp.MustCompile(windowsManufacturerRegex) + name := re.FindString(line) + name = strings.TrimLeft(name, ":") + return strings.NewReader(name), nil + } + } + + return nil, errors.New("cannot determine the machine's manufacturer") + default: + return nil, platformError(runningOS) + } + } + vmOnGCP bool +) + +// isRunningOnGCP checks whether the local system, without doing a network request is +// running on GCP. +func isRunningOnGCP() bool { + manufacturer, err := readManufacturer() + if err != nil { + log.Fatalf("failure to read manufacturer information: %v", err) + } + name := string(manufacturer) + switch runningOS { + case "linux": + name = strings.TrimSpace(name) + return name == "Google" || name == "Google Compute Engine" + case "windows": + name = strings.Replace(name, " ", "", -1) + name = strings.Replace(name, "\n", "", -1) + name = strings.Replace(name, "\r", "", -1) + return name == "Google" + default: + log.Fatal(platformError(runningOS)) + } + return false +} + +func readManufacturer() ([]byte, error) { + reader, err := manufacturerReader() + if err != nil { + return nil, err + } + if reader == nil { + return nil, errors.New("got nil reader") + } + manufacturer, err := ioutil.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("failed reading %v: %v", linuxProductNameFile, err) + } + return manufacturer, nil +} diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go new file mode 100644 index 00000000..32c5e1bf --- /dev/null +++ b/credentials/alts/utils_test.go @@ -0,0 +1,66 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package alts + +import ( + "io" + "strings" + "testing" +) + +func TestIsRunningOnGCP(t *testing.T) { + for _, tc := range []struct { + description string + testOS string + testReader io.Reader + out bool + }{ + // Linux tests. + {"linux: not a GCP platform", "linux", strings.NewReader("not GCP"), false}, + {"Linux: GCP platform (Google)", "linux", strings.NewReader("Google"), true}, + {"Linux: GCP platform (Google Compute Engine)", "linux", strings.NewReader("Google Compute Engine"), true}, + {"Linux: GCP platform (Google Compute Engine) with extra spaces", "linux", strings.NewReader(" Google Compute Engine "), true}, + // Windows tests. + {"windows: not a GCP platform", "windows", strings.NewReader("not GCP"), false}, + {"windows: GCP platform (Google)", "windows", strings.NewReader("Google"), true}, + {"windows: GCP platform (Google) with extra spaces", "windows", strings.NewReader(" Google "), true}, + } { + reverseFunc := setup(tc.testOS, tc.testReader) + if got, want := isRunningOnGCP(), tc.out; got != want { + t.Errorf("%v: isRunningOnGCP()=%v, want %v", tc.description, got, want) + } + reverseFunc() + } +} + +func setup(testOS string, testReader io.Reader) func() { + tmpOS := runningOS + tmpReader := manufacturerReader + + // Set test OS and reader function. + runningOS = testOS + manufacturerReader = func() (io.Reader, error) { + return testReader, nil + } + + return func() { + runningOS = tmpOS + manufacturerReader = tmpReader + } +} diff --git a/interop/alts/client/client.go b/interop/alts/client/client.go new file mode 100644 index 00000000..0b61470a --- /dev/null +++ b/interop/alts/client/client.go @@ -0,0 +1,65 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This binary can only run on Google Cloud Platform (GCP). + +package main + +import ( + "flag" + "time" + + "golang.org/x/net/context" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/credentials/alts" + "google.golang.org/grpc/grpclog" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +const ( + value = "test_value" +) + +var ( + serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening") +) + +func main() { + flag.Parse() + + altsTC := alts.NewClientALTS(nil) + // Block until the server is ready. + conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock()) + if err != nil { + grpclog.Fatalf("gRPC Client: failed to dial the server at %v: %v", *serverAddr, err) + } + defer conn.Close() + grpcClient := testpb.NewTestServiceClient(conn) + + // Call the EmptyCall API. + ctx := context.Background() + request := &testpb.Empty{} + if _, err := grpcClient.EmptyCall(ctx, request); err != nil { + grpclog.Fatalf("grpc Client: EmptyCall(_, %v) failed: %v", request, err) + } + grpclog.Info("grpc Client: empty call succeeded") + + // This sleep prevents the connection from being abruptly disconnected + // when running this binary (along with grpc_server) on GCP dev cluster. + time.Sleep(1 * time.Second) +} diff --git a/interop/alts/server/server.go b/interop/alts/server/server.go new file mode 100644 index 00000000..dcbf2c79 --- /dev/null +++ b/interop/alts/server/server.go @@ -0,0 +1,49 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This binary can only run on Google Cloud Platform (GCP). + +package main + +import ( + "flag" + "net" + + grpc "google.golang.org/grpc" + "google.golang.org/grpc/credentials/alts" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/interop" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +var ( + serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening") +) + +func main() { + flag.Parse() + + lis, err := net.Listen("tcp", *serverAddr) + if err != nil { + grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err) + } + altsTC := alts.NewServerALTS() + grpcServer := grpc.NewServer(grpc.Creds(altsTC)) + testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer()) + grpcServer.Serve(lis) +}