diff --git a/rpc_util.go b/rpc_util.go
index cf9dbe7f..8644b8a7 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -155,7 +155,6 @@ func (d *gzipDecompressor) Type() string {
 type callInfo struct {
 	compressorType        string
 	failFast              bool
-	stream                ClientStream
 	maxReceiveMessageSize *int
 	maxSendMessageSize    *int
 	creds                 credentials.PerRPCCredentials
@@ -180,7 +179,7 @@ type CallOption interface {
 
 	// after is called after the call has completed.  after cannot return an
 	// error, so any failures should be reported via output parameters.
-	after(*callInfo)
+	after(*callInfo, *csAttempt)
 }
 
 // EmptyCallOption does not alter the Call configuration.
@@ -188,8 +187,8 @@ type CallOption interface {
 // by interceptors.
 type EmptyCallOption struct{}
 
-func (EmptyCallOption) before(*callInfo) error { return nil }
-func (EmptyCallOption) after(*callInfo)        {}
+func (EmptyCallOption) before(*callInfo) error      { return nil }
+func (EmptyCallOption) after(*callInfo, *csAttempt) {}
 
 // Header returns a CallOptions that retrieves the header metadata
 // for a unary RPC.
@@ -205,10 +204,8 @@ type HeaderCallOption struct {
 }
 
 func (o HeaderCallOption) before(c *callInfo) error { return nil }
-func (o HeaderCallOption) after(c *callInfo) {
-	if c.stream != nil {
-		*o.HeaderAddr, _ = c.stream.Header()
-	}
+func (o HeaderCallOption) after(c *callInfo, attempt *csAttempt) {
+	*o.HeaderAddr, _ = attempt.s.Header()
 }
 
 // Trailer returns a CallOptions that retrieves the trailer metadata
@@ -225,10 +222,8 @@ type TrailerCallOption struct {
 }
 
 func (o TrailerCallOption) before(c *callInfo) error { return nil }
-func (o TrailerCallOption) after(c *callInfo) {
-	if c.stream != nil {
-		*o.TrailerAddr = c.stream.Trailer()
-	}
+func (o TrailerCallOption) after(c *callInfo, attempt *csAttempt) {
+	*o.TrailerAddr = attempt.s.Trailer()
 }
 
 // Peer returns a CallOption that retrieves peer information for a unary RPC.
@@ -245,11 +240,9 @@ type PeerCallOption struct {
 }
 
 func (o PeerCallOption) before(c *callInfo) error { return nil }
-func (o PeerCallOption) after(c *callInfo) {
-	if c.stream != nil {
-		if x, ok := peer.FromContext(c.stream.Context()); ok {
-			*o.PeerAddr = *x
-		}
+func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
+	if x, ok := peer.FromContext(attempt.s.Context()); ok {
+		*o.PeerAddr = *x
 	}
 }
 
@@ -285,7 +278,7 @@ func (o FailFastCallOption) before(c *callInfo) error {
 	c.failFast = o.FailFast
 	return nil
 }
-func (o FailFastCallOption) after(c *callInfo) {}
+func (o FailFastCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // MaxCallRecvMsgSize returns a CallOption which sets the maximum message size
 // in bytes the client can receive.
@@ -304,7 +297,7 @@ func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
 	c.maxReceiveMessageSize = &o.MaxRecvMsgSize
 	return nil
 }
-func (o MaxRecvMsgSizeCallOption) after(c *callInfo) {}
+func (o MaxRecvMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // MaxCallSendMsgSize returns a CallOption which sets the maximum message size
 // in bytes the client can send.
@@ -323,7 +316,7 @@ func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
 	c.maxSendMessageSize = &o.MaxSendMsgSize
 	return nil
 }
-func (o MaxSendMsgSizeCallOption) after(c *callInfo) {}
+func (o MaxSendMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
 // for a call.
@@ -342,7 +335,7 @@ func (o PerRPCCredsCallOption) before(c *callInfo) error {
 	c.creds = o.Creds
 	return nil
 }
-func (o PerRPCCredsCallOption) after(c *callInfo) {}
+func (o PerRPCCredsCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // UseCompressor returns a CallOption which sets the compressor used when
 // sending the request.  If WithCompressor is also set, UseCompressor has
@@ -363,7 +356,7 @@ func (o CompressorCallOption) before(c *callInfo) error {
 	c.compressorType = o.CompressorType
 	return nil
 }
-func (o CompressorCallOption) after(c *callInfo) {}
+func (o CompressorCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // CallContentSubtype returns a CallOption that will set the content-subtype
 // for a call. For example, if content-subtype is "json", the Content-Type over
@@ -396,7 +389,7 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error {
 	c.contentSubtype = o.ContentSubtype
 	return nil
 }
-func (o ContentSubtypeCallOption) after(c *callInfo) {}
+func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // ForceCodec returns a CallOption that will set the given Codec to be
 // used for all request and response messages for a call. The result of calling
@@ -428,7 +421,7 @@ func (o ForceCodecCallOption) before(c *callInfo) error {
 	c.codec = o.Codec
 	return nil
 }
-func (o ForceCodecCallOption) after(c *callInfo) {}
+func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
 // an encoding.Codec.
@@ -450,7 +443,7 @@ func (o CustomCodecCallOption) before(c *callInfo) error {
 	c.codec = o.Codec
 	return nil
 }
-func (o CustomCodecCallOption) after(c *callInfo) {}
+func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // MaxRetryRPCBufferSize returns a CallOption that limits the amount of memory
 // used for buffering this RPC's requests for retry purposes.
@@ -471,7 +464,7 @@ func (o MaxRetryRPCBufferSizeCallOption) before(c *callInfo) error {
 	c.maxRetryRPCBufferSize = o.MaxRetryRPCBufferSize
 	return nil
 }
-func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo) {}
+func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo, attempt *csAttempt) {}
 
 // The format of the payload: compressed or not?
 type payloadFormat uint8
diff --git a/stream.go b/stream.go
index 934ef683..62d51334 100644
--- a/stream.go
+++ b/stream.go
@@ -277,7 +277,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	}
 	cs.binlog = binarylog.GetMethodLogger(method)
 
-	cs.callInfo.stream = cs
 	// Only this initial attempt has stats/tracing.
 	// TODO(dfawley): move to newAttempt when per-attempt stats are implemented.
 	if err := cs.newAttemptLocked(sh, trInfo); err != nil {
@@ -799,6 +798,15 @@ func (cs *clientStream) finish(err error) {
 	}
 	cs.finished = true
 	cs.commitAttemptLocked()
+	if cs.attempt != nil {
+		cs.attempt.finish(err)
+		// after functions all rely upon having a stream.
+		if cs.attempt.s != nil {
+			for _, o := range cs.opts {
+				o.after(cs.callInfo, cs.attempt)
+			}
+		}
+	}
 	cs.mu.Unlock()
 	// For binary logging. only log cancel in finish (could be caused by RPC ctx
 	// canceled or ClientConn closed). Trailer will be logged in RecvMsg.
@@ -820,15 +828,6 @@ func (cs *clientStream) finish(err error) {
 			cs.cc.incrCallsSucceeded()
 		}
 	}
-	if cs.attempt != nil {
-		cs.attempt.finish(err)
-		// after functions all rely upon having a stream.
-		if cs.attempt.s != nil {
-			for _, o := range cs.opts {
-				o.after(cs.callInfo)
-			}
-		}
-	}
 	cs.cancel()
 }
 
@@ -1066,7 +1065,6 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
 		t:        t,
 	}
 
-	as.callInfo.stream = as
 	s, err := as.t.NewStream(as.ctx, as.callHdr)
 	if err != nil {
 		err = toRPCErr(err)
diff --git a/test/end2end_test.go b/test/end2end_test.go
index f3a60de5..82150fe1 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -7128,3 +7128,62 @@ func (s) TestGzipBadChecksum(t *testing.T) {
 		t.Errorf("ss.client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum)
 	}
 }
+
+// When an RPC is canceled, it's possible that the last Recv() returns before
+// all call options' after are executed.
+func (s) TestCanceledRPCCallOptionRace(t *testing.T) {
+	ss := &stubServer{
+		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
+			err := stream.Send(&testpb.StreamingOutputCallResponse{})
+			if err != nil {
+				return err
+			}
+			<-stream.Context().Done()
+			return nil
+		},
+	}
+	if err := ss.Start(nil); err != nil {
+		t.Fatalf("Error starting endpoint server: %v", err)
+	}
+	defer ss.Stop()
+
+	const count = 1000
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	var wg sync.WaitGroup
+	wg.Add(count)
+	for i := 0; i < count; i++ {
+		go func() {
+			defer wg.Done()
+			var p peer.Peer
+			ctx, cancel := context.WithCancel(ctx)
+			defer cancel()
+			stream, err := ss.client.FullDuplexCall(ctx, grpc.Peer(&p))
+			if err != nil {
+				t.Errorf("_.FullDuplexCall(_) = _, %v", err)
+				return
+			}
+			if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
+				t.Errorf("_ has error %v while sending", err)
+				return
+			}
+			if _, err := stream.Recv(); err != nil {
+				t.Errorf("%v.Recv() = %v", stream, err)
+				return
+			}
+			cancel()
+			if _, err := stream.Recv(); status.Code(err) != codes.Canceled {
+				t.Errorf("%v compleled with error %v, want %s", stream, err, codes.Canceled)
+				return
+			}
+			// If recv returns before call options are executed, peer.Addr is not set,
+			// fail the test.
+			if p.Addr == nil {
+				t.Errorf("peer.Addr is nil, want non-nil")
+				return
+			}
+		}()
+	}
+	wg.Wait()
+}