vendor: google.golang.org/grpc v1.66.2

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
(cherry picked from commit b6d27ff60e)
Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2024-10-12 21:52:41 +02:00
parent 5cbb4ca191
commit 2668b11ce4
No known key found for this signature in database
GPG Key ID: 76698F39D527CE8C
100 changed files with 5311 additions and 3924 deletions

View File

@ -97,8 +97,8 @@ require (
golang.org/x/crypto v0.27.0 // indirect golang.org/x/crypto v0.27.0 // indirect
golang.org/x/net v0.29.0 // indirect golang.org/x/net v0.29.0 // indirect
golang.org/x/time v0.6.0 // indirect golang.org/x/time v0.6.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
google.golang.org/grpc v1.62.0 // indirect google.golang.org/grpc v1.66.2 // indirect
google.golang.org/protobuf v1.34.1 // indirect google.golang.org/protobuf v1.34.1 // indirect
) )

View File

@ -105,8 +105,8 @@ github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4=
github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@ -414,15 +414,13 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 h1:+rdxYoE3E5htTEWIe15GlN6IfvbURM//Jt0mmkmm6ZU=
google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80/go.mod h1:cc8bqMqtv9gMOr0zHg2Vzff5ULhhL2IXP4sbcn32Dro= google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117/go.mod h1:OimBR/bc1wPO9iV4NC2bpyjy3VnAwZh5EBPQdtaE5oo=
google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 h1:Lj5rbfG876hIAYFjqiJnPHfhXbv+nzTWfm04Fg/XSVU= google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 h1:1GBuWVLM/KMVUv1t1En5Gs+gFZCNd360GGb4sSxtrhU=
google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80/go.mod h1:4jWUdICTdgc3Ibxmr8nAJiiLHwQBY0UI0XZcEMaFKaA= google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 h1:AjyfHzEPEFp/NpvfN5g+KDla3EMojjhRVZc1i7cj+oM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s=
google.golang.org/grpc v1.0.5/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.0.5/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.62.0 h1:HQKZ/fa1bXkX1oFOvSjmZEUL8wLSaZTjCcLAlmZRtdk= google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo=
google.golang.org/grpc v1.62.0/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=

View File

@ -1,4 +1,4 @@
// Copyright 2023 Google LLC // Copyright 2024 Google LLC
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.21.9 // protoc v4.24.4
// source: google/api/httpbody.proto // source: google/api/httpbody.proto
package httpbody package httpbody

View File

@ -1,4 +1,4 @@
// Copyright 2022 Google LLC // Copyright 2024 Google LLC
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.21.9 // protoc v4.24.4
// source: google/rpc/error_details.proto // source: google/rpc/error_details.proto
package errdetails package errdetails

View File

@ -1,4 +1,4 @@
// Copyright 2022 Google LLC // Copyright 2024 Google LLC
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.26.0 // protoc-gen-go v1.26.0
// protoc v3.21.9 // protoc v4.24.4
// source: google/rpc/status.proto // source: google/rpc/status.proto
package status package status

View File

@ -66,7 +66,7 @@ How to get your contributions merged smoothly and quickly.
- **All tests need to be passing** before your change can be merged. We - **All tests need to be passing** before your change can be merged. We
recommend you **run tests locally** before creating your PR to catch breakages recommend you **run tests locally** before creating your PR to catch breakages
early on. early on.
- `VET_SKIP_PROTO=1 ./vet.sh` to catch vet errors - `./scripts/vet.sh` to catch vet errors
- `go test -cpu 1,4 -timeout 7m ./...` to run the tests - `go test -cpu 1,4 -timeout 7m ./...` to run the tests
- `go test -race -cpu 1,4 -timeout 7m ./...` to run tests in race mode - `go test -race -cpu 1,4 -timeout 7m ./...` to run tests in race mode

View File

@ -9,20 +9,28 @@ for general contribution guidelines.
## Maintainers (in alphabetical order) ## Maintainers (in alphabetical order)
- [cesarghali](https://github.com/cesarghali), Google LLC - [aranjans](https://github.com/aranjans), Google LLC
- [arjan-bal](https://github.com/arjan-bal), Google LLC
- [arvindbr8](https://github.com/arvindbr8), Google LLC
- [atollena](https://github.com/atollena), Datadog, Inc.
- [dfawley](https://github.com/dfawley), Google LLC - [dfawley](https://github.com/dfawley), Google LLC
- [easwars](https://github.com/easwars), Google LLC - [easwars](https://github.com/easwars), Google LLC
- [menghanl](https://github.com/menghanl), Google LLC - [erm-g](https://github.com/erm-g), Google LLC
- [srini100](https://github.com/srini100), Google LLC - [gtcooke94](https://github.com/gtcooke94), Google LLC
- [purnesh42h](https://github.com/purnesh42h), Google LLC
- [zasweq](https://github.com/zasweq), Google LLC
## Emeritus Maintainers (in alphabetical order) ## Emeritus Maintainers (in alphabetical order)
- [adelez](https://github.com/adelez), Google LLC - [adelez](https://github.com/adelez)
- [canguler](https://github.com/canguler), Google LLC - [canguler](https://github.com/canguler)
- [iamqizhao](https://github.com/iamqizhao), Google LLC - [cesarghali](https://github.com/cesarghali)
- [jadekler](https://github.com/jadekler), Google LLC - [iamqizhao](https://github.com/iamqizhao)
- [jtattermusch](https://github.com/jtattermusch), Google LLC - [jeanbza](https://github.com/jeanbza)
- [lyuxuan](https://github.com/lyuxuan), Google LLC - [jtattermusch](https://github.com/jtattermusch)
- [makmukhi](https://github.com/makmukhi), Google LLC - [lyuxuan](https://github.com/lyuxuan)
- [matt-kwong](https://github.com/matt-kwong), Google LLC - [makmukhi](https://github.com/makmukhi)
- [nicolasnoble](https://github.com/nicolasnoble), Google LLC - [matt-kwong](https://github.com/matt-kwong)
- [yongni](https://github.com/yongni), Google LLC - [menghanl](https://github.com/menghanl)
- [nicolasnoble](https://github.com/nicolasnoble)
- [srini100](https://github.com/srini100)
- [yongni](https://github.com/yongni)

View File

@ -30,17 +30,20 @@ testdeps:
GO111MODULE=on go get -d -v -t google.golang.org/grpc/... GO111MODULE=on go get -d -v -t google.golang.org/grpc/...
vet: vetdeps vet: vetdeps
./vet.sh ./scripts/vet.sh
vetdeps: vetdeps:
./vet.sh -install ./scripts/vet.sh -install
.PHONY: \ .PHONY: \
all \ all \
build \ build \
clean \ clean \
deps \
proto \ proto \
test \ test \
testsubmodule \
testrace \ testrace \
testdeps \
vet \ vet \
vetdeps vetdeps

View File

@ -10,7 +10,7 @@ RPC framework that puts mobile and HTTP/2 first. For more information see the
## Prerequisites ## Prerequisites
- **[Go][]**: any one of the **three latest major** [releases][go-releases]. - **[Go][]**: any one of the **two latest major** [releases][go-releases].
## Installation ## Installation

View File

@ -1,3 +1,3 @@
# Security Policy # Security Policy
For information on gRPC Security Policy and reporting potentional security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md). For information on gRPC Security Policy and reporting potential security issues, please see [gRPC CVE Process](https://github.com/grpc/proposal/blob/master/P4-grpc-cve-process.md).

View File

@ -39,7 +39,7 @@ type Config struct {
MaxDelay time.Duration MaxDelay time.Duration
} }
// DefaultConfig is a backoff configuration with the default values specfied // DefaultConfig is a backoff configuration with the default values specified
// at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md. // at https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md.
// //
// This should be useful for callers who want to configure backoff with // This should be useful for callers who want to configure backoff with

View File

@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/channelz" "google.golang.org/grpc/channelz"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -54,13 +55,14 @@ var (
// an init() function), and is not thread-safe. If multiple Balancers are // an init() function), and is not thread-safe. If multiple Balancers are
// registered with the same name, the one registered last will take effect. // registered with the same name, the one registered last will take effect.
func Register(b Builder) { func Register(b Builder) {
if strings.ToLower(b.Name()) != b.Name() { name := strings.ToLower(b.Name())
if name != b.Name() {
// TODO: Skip the use of strings.ToLower() to index the map after v1.59 // TODO: Skip the use of strings.ToLower() to index the map after v1.59
// is released to switch to case sensitive balancer registry. Also, // is released to switch to case sensitive balancer registry. Also,
// remove this warning and update the docstrings for Register and Get. // remove this warning and update the docstrings for Register and Get.
logger.Warningf("Balancer registered with name %q. grpc-go will be switching to case sensitive balancer registries soon", b.Name()) logger.Warningf("Balancer registered with name %q. grpc-go will be switching to case sensitive balancer registries soon", b.Name())
} }
m[strings.ToLower(b.Name())] = b m[name] = b
} }
// unregisterForTesting deletes the balancer with the given name from the // unregisterForTesting deletes the balancer with the given name from the
@ -71,8 +73,21 @@ func unregisterForTesting(name string) {
delete(m, name) delete(m, name)
} }
// connectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
func connectedAddress(scs SubConnState) resolver.Address {
return scs.connectedAddress
}
// setConnectedAddress sets the connected address for a SubConnState.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
scs.connectedAddress = addr
}
func init() { func init() {
internal.BalancerUnregister = unregisterForTesting internal.BalancerUnregister = unregisterForTesting
internal.ConnectedAddress = connectedAddress
internal.SetConnectedAddress = setConnectedAddress
} }
// Get returns the resolver builder registered with the given name. // Get returns the resolver builder registered with the given name.
@ -232,8 +247,8 @@ type BuildOptions struct {
// implementations which do not communicate with a remote load balancer // implementations which do not communicate with a remote load balancer
// server can ignore this field. // server can ignore this field.
Authority string Authority string
// ChannelzParentID is the parent ClientConn's channelz ID. // ChannelzParent is the parent ClientConn's channelz channel.
ChannelzParentID *channelz.Identifier ChannelzParent channelz.Identifier
// CustomUserAgent is the custom user agent set on the parent ClientConn. // CustomUserAgent is the custom user agent set on the parent ClientConn.
// The balancer should set the same custom user agent if it creates a // The balancer should set the same custom user agent if it creates a
// ClientConn. // ClientConn.
@ -242,6 +257,10 @@ type BuildOptions struct {
// same resolver.Target as passed to the resolver. See the documentation for // same resolver.Target as passed to the resolver. See the documentation for
// the resolver.Target type for details about what it contains. // the resolver.Target type for details about what it contains.
Target resolver.Target Target resolver.Target
// MetricsRecorder is the metrics recorder that balancers can use to record
// metrics. Balancer implementations which do not register metrics on
// metrics registry and record on them can ignore this field.
MetricsRecorder estats.MetricsRecorder
} }
// Builder creates a balancer. // Builder creates a balancer.
@ -409,6 +428,9 @@ type SubConnState struct {
// ConnectionError is set if the ConnectivityState is TransientFailure, // ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil. // describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error ConnectionError error
// connectedAddr contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
} }
// ClientConnState describes the state of a ClientConn relevant to the // ClientConnState describes the state of a ClientConn relevant to the

View File

@ -16,54 +16,60 @@
* *
*/ */
package grpc // Package pickfirst contains the pick_first load balancing policy.
package pickfirst
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
internalgrpclog "google.golang.org/grpc/internal/grpclog" internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
func init() {
balancer.Register(pickfirstBuilder{})
internal.ShuffleAddressListForTesting = func(n int, swap func(i, j int)) { rand.Shuffle(n, swap) }
}
var logger = grpclog.Component("pick-first-lb")
const ( const (
// PickFirstBalancerName is the name of the pick_first balancer. // Name is the name of the pick_first balancer.
PickFirstBalancerName = "pick_first" Name = "pick_first"
logPrefix = "[pick-first-lb %p] " logPrefix = "[pick-first-lb %p] "
) )
func newPickfirstBuilder() balancer.Builder {
return &pickfirstBuilder{}
}
type pickfirstBuilder struct{} type pickfirstBuilder struct{}
func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { func (pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{cc: cc} b := &pickfirstBalancer{cc: cc}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b return b
} }
func (*pickfirstBuilder) Name() string { func (pickfirstBuilder) Name() string {
return PickFirstBalancerName return Name
} }
type pfConfig struct { type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"` serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list // If set to true, instructs the LB policy to shuffle the order of the list
// of addresses received from the name resolver before attempting to // of endpoints received from the name resolver before attempting to
// connect to them. // connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"` ShuffleAddressList bool `json:"shuffleAddressList"`
} }
func (*pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil { if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err) return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
@ -97,9 +103,14 @@ func (b *pickfirstBalancer) ResolverError(err error) {
}) })
} }
type Shuffler interface {
ShuffleAddressListForTesting(n int, swap func(i, j int))
}
func ShuffleAddressListForTesting(n int, swap func(i, j int)) { rand.Shuffle(n, swap) }
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
addrs := state.ResolverState.Addresses if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
if len(addrs) == 0 {
// The resolver reported an empty address list. Treat it like an error by // The resolver reported an empty address list. Treat it like an error by
// calling b.ResolverError. // calling b.ResolverError.
if b.subConn != nil { if b.subConn != nil {
@ -111,22 +122,49 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
b.ResolverError(errors.New("produced zero addresses")) b.ResolverError(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState return balancer.ErrBadResolverState
} }
// We don't have to guard this block with the env var because ParseConfig // We don't have to guard this block with the env var because ParseConfig
// already does so. // already does so.
cfg, ok := state.BalancerConfig.(pfConfig) cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok { if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig) return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig)
} }
if cfg.ShuffleAddressList {
addrs = append([]resolver.Address{}, addrs...)
grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
}
if b.logger.V(2) { if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState)) b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
} }
var addrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling will
// change the order of endpoints but not touch the order of the addresses
// within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.ShuffleAddressListForTesting.(func(int, func(int, int)))(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for each
// of the endpoints, in order." - A61
for _, endpoint := range endpoints {
// "In the flattened list, interleave addresses from the two address
// families, as per RFC-8304 section 4." - A61
// TODO: support the above language.
addrs = append(addrs, endpoint.Addresses...)
}
} else {
// Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers such as weighted
// target do not forward the corresponding correct endpoints down/split
// endpoints properly. Once all balancers correctly forward endpoints
// down, can delete this else conditional.
addrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
addrs = append([]resolver.Address{}, addrs...)
rand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
}
}
if b.subConn != nil { if b.subConn != nil {
b.cc.UpdateAddresses(b.subConn, addrs) b.cc.UpdateAddresses(b.subConn, addrs)
return nil return nil
@ -243,7 +281,3 @@ func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.subConn.Connect() i.subConn.Connect()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
} }
func init() {
balancer.Register(newPickfirstBuilder())
}

View File

@ -22,12 +22,12 @@
package roundrobin package roundrobin
import ( import (
"math/rand"
"sync/atomic" "sync/atomic"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcrand"
) )
// Name is the name of round_robin balancer. // Name is the name of round_robin balancer.
@ -60,7 +60,7 @@ func (*rrPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
// Start at a random index, as the same RR balancer rebuilds a new // Start at a random index, as the same RR balancer rebuilds a new
// picker when SubConn states change, and we don't want to apply excess // picker when SubConn states change, and we don't want to apply excess
// load to the first server in the list. // load to the first server in the list.
next: uint32(grpcrand.Intn(len(scs))), next: uint32(rand.Intn(len(scs))),
} }
} }

View File

@ -21,17 +21,19 @@ package grpc
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
var setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address))
// ccBalancerWrapper sits between the ClientConn and the Balancer. // ccBalancerWrapper sits between the ClientConn and the Balancer.
// //
// ccBalancerWrapper implements methods corresponding to the ones on the // ccBalancerWrapper implements methods corresponding to the ones on the
@ -66,7 +68,8 @@ type ccBalancerWrapper struct {
} }
// newCCBalancerWrapper creates a new balancer wrapper in idle state. The // newCCBalancerWrapper creates a new balancer wrapper in idle state. The
// underlying balancer is not created until the switchTo() method is invoked. // underlying balancer is not created until the updateClientConnState() method
// is invoked.
func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper { func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
ctx, cancel := context.WithCancel(cc.ctx) ctx, cancel := context.WithCancel(cc.ctx)
ccb := &ccBalancerWrapper{ ccb := &ccBalancerWrapper{
@ -77,8 +80,9 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
Dialer: cc.dopts.copts.Dialer, Dialer: cc.dopts.copts.Dialer,
Authority: cc.authority, Authority: cc.authority,
CustomUserAgent: cc.dopts.copts.UserAgent, CustomUserAgent: cc.dopts.copts.UserAgent,
ChannelzParentID: cc.channelzID, ChannelzParent: cc.channelz,
Target: cc.parsedTarget, Target: cc.parsedTarget,
MetricsRecorder: cc.metricsRecorderList,
}, },
serializer: grpcsync.NewCallbackSerializer(ctx), serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel, serializerCancel: cancel,
@ -92,27 +96,38 @@ func newCCBalancerWrapper(cc *ClientConn) *ccBalancerWrapper {
// it is safe to call into the balancer here. // it is safe to call into the balancer here.
func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error { func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) error {
errCh := make(chan error) errCh := make(chan error)
ok := ccb.serializer.Schedule(func(ctx context.Context) { uccs := func(ctx context.Context) {
defer close(errCh) defer close(errCh)
if ctx.Err() != nil || ccb.balancer == nil { if ctx.Err() != nil || ccb.balancer == nil {
return return
} }
name := gracefulswitch.ChildName(ccs.BalancerConfig)
if ccb.curBalancerName != name {
ccb.curBalancerName = name
channelz.Infof(logger, ccb.cc.channelz, "Channel switches to new LB policy %q", name)
}
err := ccb.balancer.UpdateClientConnState(*ccs) err := ccb.balancer.UpdateClientConnState(*ccs)
if logger.V(2) && err != nil { if logger.V(2) && err != nil {
logger.Infof("error from balancer.UpdateClientConnState: %v", err) logger.Infof("error from balancer.UpdateClientConnState: %v", err)
} }
errCh <- err errCh <- err
})
if !ok {
return nil
} }
onFailure := func() { close(errCh) }
// UpdateClientConnState can race with Close, and when the latter wins, the
// serializer is closed, and the attempt to schedule the callback will fail.
// It is acceptable to ignore this failure. But since we want to handle the
// state update in a blocking fashion (when we successfully schedule the
// callback), we have to use the ScheduleOr method and not the MaybeSchedule
// method on the serializer.
ccb.serializer.ScheduleOr(uccs, onFailure)
return <-errCh return <-errCh
} }
// resolverError is invoked by grpc to push a resolver error to the underlying // resolverError is invoked by grpc to push a resolver error to the underlying
// balancer. The call to the balancer is executed from the serializer. // balancer. The call to the balancer is executed from the serializer.
func (ccb *ccBalancerWrapper) resolverError(err error) { func (ccb *ccBalancerWrapper) resolverError(err error) {
ccb.serializer.Schedule(func(ctx context.Context) { ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil { if ctx.Err() != nil || ccb.balancer == nil {
return return
} }
@ -120,54 +135,6 @@ func (ccb *ccBalancerWrapper) resolverError(err error) {
}) })
} }
// switchTo is invoked by grpc to instruct the balancer wrapper to switch to the
// LB policy identified by name.
//
// ClientConn calls newCCBalancerWrapper() at creation time. Upon receipt of the
// first good update from the name resolver, it determines the LB policy to use
// and invokes the switchTo() method. Upon receipt of every subsequent update
// from the name resolver, it invokes this method.
//
// the ccBalancerWrapper keeps track of the current LB policy name, and skips
// the graceful balancer switching process if the name does not change.
func (ccb *ccBalancerWrapper) switchTo(name string) {
ccb.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil {
return
}
// TODO: Other languages use case-sensitive balancer registries. We should
// switch as well. See: https://github.com/grpc/grpc-go/issues/5288.
if strings.EqualFold(ccb.curBalancerName, name) {
return
}
ccb.buildLoadBalancingPolicy(name)
})
}
// buildLoadBalancingPolicy performs the following:
// - retrieve a balancer builder for the given name. Use the default LB
// policy, pick_first, if no LB policy with name is found in the registry.
// - instruct the gracefulswitch balancer to switch to the above builder. This
// will actually build the new balancer.
// - update the `curBalancerName` field
//
// Must be called from a serializer callback.
func (ccb *ccBalancerWrapper) buildLoadBalancingPolicy(name string) {
builder := balancer.Get(name)
if builder == nil {
channelz.Warningf(logger, ccb.cc.channelzID, "Channel switches to new LB policy %q, since the specified LB policy %q was not registered", PickFirstBalancerName, name)
builder = newPickfirstBuilder()
} else {
channelz.Infof(logger, ccb.cc.channelzID, "Channel switches to new LB policy %q", name)
}
if err := ccb.balancer.SwitchTo(builder); err != nil {
channelz.Errorf(logger, ccb.cc.channelzID, "Channel failed to build new LB policy %q: %v", name, err)
return
}
ccb.curBalancerName = builder.Name()
}
// close initiates async shutdown of the wrapper. cc.mu must be held when // close initiates async shutdown of the wrapper. cc.mu must be held when
// calling this function. To determine the wrapper has finished shutting down, // calling this function. To determine the wrapper has finished shutting down,
// the channel should block on ccb.serializer.Done() without cc.mu held. // the channel should block on ccb.serializer.Done() without cc.mu held.
@ -175,8 +142,8 @@ func (ccb *ccBalancerWrapper) close() {
ccb.mu.Lock() ccb.mu.Lock()
ccb.closed = true ccb.closed = true
ccb.mu.Unlock() ccb.mu.Unlock()
channelz.Info(logger, ccb.cc.channelzID, "ccBalancerWrapper: closing") channelz.Info(logger, ccb.cc.channelz, "ccBalancerWrapper: closing")
ccb.serializer.Schedule(func(context.Context) { ccb.serializer.TrySchedule(func(context.Context) {
if ccb.balancer == nil { if ccb.balancer == nil {
return return
} }
@ -188,7 +155,7 @@ func (ccb *ccBalancerWrapper) close() {
// exitIdle invokes the balancer's exitIdle method in the serializer. // exitIdle invokes the balancer's exitIdle method in the serializer.
func (ccb *ccBalancerWrapper) exitIdle() { func (ccb *ccBalancerWrapper) exitIdle() {
ccb.serializer.Schedule(func(ctx context.Context) { ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccb.balancer == nil { if ctx.Err() != nil || ccb.balancer == nil {
return return
} }
@ -212,7 +179,7 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
} }
ac, err := ccb.cc.newAddrConnLocked(addrs, opts) ac, err := ccb.cc.newAddrConnLocked(addrs, opts)
if err != nil { if err != nil {
channelz.Warningf(logger, ccb.cc.channelzID, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err) channelz.Warningf(logger, ccb.cc.channelz, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
return nil, err return nil, err
} }
acbw := &acBalancerWrapper{ acbw := &acBalancerWrapper{
@ -241,6 +208,10 @@ func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resol
func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) { func (ccb *ccBalancerWrapper) UpdateState(s balancer.State) {
ccb.cc.mu.Lock() ccb.cc.mu.Lock()
defer ccb.cc.mu.Unlock() defer ccb.cc.mu.Unlock()
if ccb.cc.conns == nil {
// The CC has been closed; ignore this update.
return
}
ccb.mu.Lock() ccb.mu.Lock()
if ccb.closed { if ccb.closed {
@ -291,20 +262,34 @@ type acBalancerWrapper struct {
// updateState is invoked by grpc to push a subConn state update to the // updateState is invoked by grpc to push a subConn state update to the
// underlying balancer. // underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) { func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.Schedule(func(ctx context.Context) { acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil { if ctx.Err() != nil || acbw.ccb.balancer == nil {
return return
} }
// Even though it is optional for balancers, gracefulswitch ensures // Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil. // opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed. // TODO: delete this comment when UpdateSubConnState is removed.
acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err}) scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
setConnectedAddress(&scs, curAddr)
}
acbw.stateListener(scs)
acbw.ac.mu.Lock()
defer acbw.ac.mu.Unlock()
if s == connectivity.Ready {
// When changing states to READY, reset stateReadyChan. Wait until
// after we notify the LB policy's listener(s) in order to prevent
// ac.getTransport() from unblocking before the LB policy starts
// tracking the subchannel as READY.
close(acbw.ac.stateReadyChan)
acbw.ac.stateReadyChan = make(chan struct{})
}
}) })
} }
func (acbw *acBalancerWrapper) String() string { func (acbw *acBalancerWrapper) String() string {
return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelzID.Int()) return fmt.Sprintf("SubConn(id:%d)", acbw.ac.channelz.ID)
} }
func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) { func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {

View File

@ -18,8 +18,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.32.0 // protoc-gen-go v1.34.1
// protoc v4.25.2 // protoc v5.27.1
// source: grpc/binlog/v1/binarylog.proto // source: grpc/binlog/v1/binarylog.proto
package grpc_binarylog_v1 package grpc_binarylog_v1

View File

@ -24,6 +24,7 @@ import (
"fmt" "fmt"
"math" "math"
"net/url" "net/url"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -31,14 +32,15 @@ import (
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base" "google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/idle" "google.golang.org/grpc/internal/idle"
"google.golang.org/grpc/internal/pretty"
iresolver "google.golang.org/grpc/internal/resolver" iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/stats"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
@ -67,12 +69,14 @@ var (
errConnDrain = errors.New("grpc: the connection is drained") errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing. // errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing") errConnClosing = errors.New("grpc: the connection is closing")
// errConnIdling indicates the the connection is being closed as the channel // errConnIdling indicates the connection is being closed as the channel
// is moving to an idle mode due to inactivity. // is moving to an idle mode due to inactivity.
errConnIdling = errors.New("grpc: the connection is closing due to channel idleness") errConnIdling = errors.New("grpc: the connection is closing due to channel idleness")
// invalidDefaultServiceConfigErrPrefix is used to prefix the json parsing error for the default // invalidDefaultServiceConfigErrPrefix is used to prefix the json parsing error for the default
// service config. // service config.
invalidDefaultServiceConfigErrPrefix = "grpc: the provided default service config is invalid" invalidDefaultServiceConfigErrPrefix = "grpc: the provided default service config is invalid"
// PickFirstBalancerName is the name of the pick_first balancer.
PickFirstBalancerName = pickfirst.Name
) )
// The following errors are returned from Dial and DialContext // The following errors are returned from Dial and DialContext
@ -101,11 +105,6 @@ const (
defaultReadBufSize = 32 * 1024 defaultReadBufSize = 32 * 1024
) )
// Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
type defaultConfigSelector struct { type defaultConfigSelector struct {
sc *ServiceConfig sc *ServiceConfig
} }
@ -117,13 +116,23 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires
}, nil }, nil
} }
// newClient returns a new client in idle mode. // NewClient creates a new gRPC "channel" for the target URI provided. No I/O
func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) { // is performed. Use of the ClientConn for RPCs will automatically cause it to
// connect. Connect may be used to manually create a connection, but for most
// users this is unnecessary.
//
// The target name syntax is defined in
// https://github.com/grpc/grpc/blob/master/doc/naming.md. e.g. to use dns
// resolver, a "dns:///" prefix should be applied to the target.
//
// The DialOptions returned by WithBlock, WithTimeout,
// WithReturnConnectionError, and FailOnNonTempDialError are ignored by this
// function.
func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[*addrConn]struct{}), conns: make(map[*addrConn]struct{}),
dopts: defaultDialOptions(), dopts: defaultDialOptions(),
czData: new(channelzData),
} }
cc.retryThrottler.Store((*retryThrottler)(nil)) cc.retryThrottler.Store((*retryThrottler)(nil))
@ -148,6 +157,16 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
for _, opt := range opts { for _, opt := range opts {
opt.apply(&cc.dopts) opt.apply(&cc.dopts)
} }
// Determine the resolver to use.
if err := cc.initParsedTargetAndResolverBuilder(); err != nil {
return nil, err
}
for _, opt := range globalPerTargetDialOptions {
opt.DialOptionForTarget(cc.parsedTarget.URL).apply(&cc.dopts)
}
chainUnaryClientInterceptors(cc) chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc) chainStreamClientInterceptors(cc)
@ -156,7 +175,7 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
} }
if cc.dopts.defaultServiceConfigRawJSON != nil { if cc.dopts.defaultServiceConfigRawJSON != nil {
scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON) scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON, cc.dopts.maxCallAttempts)
if scpr.Err != nil { if scpr.Err != nil {
return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err) return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err)
} }
@ -164,66 +183,57 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
} }
cc.mkp = cc.dopts.copts.KeepaliveParams cc.mkp = cc.dopts.copts.KeepaliveParams
// Register ClientConn with channelz. if err = cc.initAuthority(); err != nil {
return nil, err
}
// Register ClientConn with channelz. Note that this is only done after
// channel creation cannot fail.
cc.channelzRegistration(target) cc.channelzRegistration(target)
channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", cc.parsedTarget)
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
// TODO: Ideally it should be impossible to error from this function after cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
// channelz registration. This will require removing some channelz logs
// from the following functions that can error. Errors can be returned to
// the user, and successful logs can be emitted here, after the checks have
// passed and channelz is subsequently registered.
// Determine the resolver to use.
if err := cc.parseTargetAndFindResolver(); err != nil {
channelz.RemoveEntry(cc.channelzID)
return nil, err
}
if err = cc.determineAuthority(); err != nil {
channelz.RemoveEntry(cc.channelzID)
return nil, err
}
cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelzID)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers) cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)
cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers)
cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc. cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc.
cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout) cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout)
return cc, nil return cc, nil
} }
// DialContext creates a client connection to the given target. By default, it's // Dial calls DialContext(context.Background(), target, opts...).
// a non-blocking dial (the function won't wait for connections to be
// established, and connecting happens in the background). To make it a blocking
// dial, use WithBlock() dial option.
// //
// In the non-blocking case, the ctx does not act against the connection. It // Deprecated: use NewClient instead. Will be supported throughout 1.x.
// only controls the setup steps. func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
// DialContext calls NewClient and then exits idle mode. If WithBlock(true) is
// used, it calls Connect and WaitForStateChange until either the context
// expires or the state of the ClientConn is Ready.
// //
// In the blocking case, ctx can be used to cancel or expire the pending // One subtle difference between NewClient and Dial and DialContext is that the
// connection. Once this function returns, the cancellation and expiration of // former uses "dns" as the default name resolver, while the latter use
// ctx will be noop. Users should call ClientConn.Close to terminate all the // "passthrough" for backward compatibility. This distinction should not matter
// pending operations after this function returns. // to most users, but could matter to legacy users that specify a custom dialer
// and expect it to receive the target string directly.
// //
// The target name syntax is defined in // Deprecated: use NewClient instead. Will be supported throughout 1.x.
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc, err := newClient(target, opts...) // At the end of this method, we kick the channel out of idle, rather than
// waiting for the first rpc.
opts = append([]DialOption{withDefaultScheme("passthrough")}, opts...)
cc, err := NewClient(target, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// We start the channel off in idle mode, but kick it out of idle now, // We start the channel off in idle mode, but kick it out of idle now,
// instead of waiting for the first RPC. Other gRPC implementations do wait // instead of waiting for the first RPC. This is the legacy behavior of
// for the first RPC to kick the channel out of idle. But doing so would be // Dial.
// a major behavior change for our users who are used to seeing the channel
// active after Dial.
//
// Taking this approach of kicking it out of idle at the end of this method
// allows us to share the code between channel creation and exiting idle
// mode. This will also make it easy for us to switch to starting the
// channel off in idle, i.e. by making newClient exported.
defer func() { defer func() {
if err != nil { if err != nil {
cc.Close() cc.Close()
@ -291,17 +301,17 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
// addTraceEvent is a helper method to add a trace event on the channel. If the // addTraceEvent is a helper method to add a trace event on the channel. If the
// channel is a nested one, the same event is also added on the parent channel. // channel is a nested one, the same event is also added on the parent channel.
func (cc *ClientConn) addTraceEvent(msg string) { func (cc *ClientConn) addTraceEvent(msg string) {
ted := &channelz.TraceEventDesc{ ted := &channelz.TraceEvent{
Desc: fmt.Sprintf("Channel %s", msg), Desc: fmt.Sprintf("Channel %s", msg),
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
} }
if cc.dopts.channelzParentID != nil { if cc.dopts.channelzParent != nil {
ted.Parent = &channelz.TraceEventDesc{ ted.Parent = &channelz.TraceEvent{
Desc: fmt.Sprintf("Nested channel(id:%d) %s", cc.channelzID.Int(), msg), Desc: fmt.Sprintf("Nested channel(id:%d) %s", cc.channelz.ID, msg),
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
} }
} }
channelz.AddTraceEvent(logger, cc.channelzID, 0, ted) channelz.AddTraceEvent(logger, cc.channelz, 0, ted)
} }
type idler ClientConn type idler ClientConn
@ -418,14 +428,15 @@ func (cc *ClientConn) validateTransportCredentials() error {
} }
// channelzRegistration registers the newly created ClientConn with channelz and // channelzRegistration registers the newly created ClientConn with channelz and
// stores the returned identifier in `cc.channelzID` and `cc.csMgr.channelzID`. // stores the returned identifier in `cc.channelz`. A channelz trace event is
// A channelz trace event is emitted for ClientConn creation. If the newly // emitted for ClientConn creation. If the newly created ClientConn is a nested
// created ClientConn is a nested one, i.e a valid parent ClientConn ID is // one, i.e a valid parent ClientConn ID is specified via a dial option, the
// specified via a dial option, the trace event is also added to the parent. // trace event is also added to the parent.
// //
// Doesn't grab cc.mu as this method is expected to be called only at Dial time. // Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) channelzRegistration(target string) { func (cc *ClientConn) channelzRegistration(target string) {
cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target) parentChannel, _ := cc.dopts.channelzParent.(*channelz.Channel)
cc.channelz = channelz.RegisterChannel(parentChannel, target)
cc.addTraceEvent("created") cc.addTraceEvent("created")
} }
@ -492,10 +503,10 @@ func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStr
} }
// newConnectivityStateManager creates an connectivityStateManager with // newConnectivityStateManager creates an connectivityStateManager with
// the specified id. // the specified channel.
func newConnectivityStateManager(ctx context.Context, id *channelz.Identifier) *connectivityStateManager { func newConnectivityStateManager(ctx context.Context, channel *channelz.Channel) *connectivityStateManager {
return &connectivityStateManager{ return &connectivityStateManager{
channelzID: id, channelz: channel,
pubSub: grpcsync.NewPubSub(ctx), pubSub: grpcsync.NewPubSub(ctx),
} }
} }
@ -510,7 +521,7 @@ type connectivityStateManager struct {
mu sync.Mutex mu sync.Mutex
state connectivity.State state connectivity.State
notifyChan chan struct{} notifyChan chan struct{}
channelzID *channelz.Identifier channelz *channelz.Channel
pubSub *grpcsync.PubSub pubSub *grpcsync.PubSub
} }
@ -527,9 +538,10 @@ func (csm *connectivityStateManager) updateState(state connectivity.State) {
return return
} }
csm.state = state csm.state = state
csm.channelz.ChannelMetrics.State.Store(&state)
csm.pubSub.Publish(state) csm.pubSub.Publish(state)
channelz.Infof(logger, csm.channelzID, "Channel Connectivity change to %v", state) channelz.Infof(logger, csm.channelz, "Channel Connectivity change to %v", state)
if csm.notifyChan != nil { if csm.notifyChan != nil {
// There are other goroutines waiting on this channel. // There are other goroutines waiting on this channel.
close(csm.notifyChan) close(csm.notifyChan)
@ -584,19 +596,19 @@ type ClientConn struct {
// The following are initialized at dial time, and are read-only after that. // The following are initialized at dial time, and are read-only after that.
target string // User's dial target. target string // User's dial target.
parsedTarget resolver.Target // See parseTargetAndFindResolver(). parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder().
authority string // See determineAuthority(). authority string // See initAuthority().
dopts dialOptions // Default and user specified dial options. dopts dialOptions // Default and user specified dial options.
channelzID *channelz.Identifier // Channelz identifier for the channel. channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See parseTargetAndFindResolver(). resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager idlenessMgr *idle.Manager
metricsRecorderList *stats.MetricsRecorderList
// The following provide their own synchronization, and therefore don't // The following provide their own synchronization, and therefore don't
// require cc.mu to be held to access them. // require cc.mu to be held to access them.
csMgr *connectivityStateManager csMgr *connectivityStateManager
pickerWrapper *pickerWrapper pickerWrapper *pickerWrapper
safeConfigSelector iresolver.SafeConfigSelector safeConfigSelector iresolver.SafeConfigSelector
czData *channelzData
retryThrottler atomic.Value // Updated from service config. retryThrottler atomic.Value // Updated from service config.
// mu protects the following fields. // mu protects the following fields.
@ -620,11 +632,6 @@ type ClientConn struct {
// WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or
// ctx expires. A true value is returned in former case and false in latter. // ctx expires. A true value is returned in former case and false in latter.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool { func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
ch := cc.csMgr.getNotifyChan() ch := cc.csMgr.getNotifyChan()
if cc.csMgr.getState() != sourceState { if cc.csMgr.getState() != sourceState {
@ -639,11 +646,6 @@ func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState connec
} }
// GetState returns the connectivity.State of ClientConn. // GetState returns the connectivity.State of ClientConn.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (cc *ClientConn) GetState() connectivity.State { func (cc *ClientConn) GetState() connectivity.State {
return cc.csMgr.getState() return cc.csMgr.getState()
} }
@ -690,7 +692,7 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
var emptyServiceConfig *ServiceConfig var emptyServiceConfig *ServiceConfig
func init() { func init() {
cfg := parseServiceConfig("{}") cfg := parseServiceConfig("{}", defaultMaxCallAttempts)
if cfg.Err != nil { if cfg.Err != nil {
panic(fmt.Sprintf("impossible error parsing empty service config: %v", cfg.Err)) panic(fmt.Sprintf("impossible error parsing empty service config: %v", cfg.Err))
} }
@ -707,15 +709,15 @@ func init() {
} }
} }
func (cc *ClientConn) maybeApplyDefaultServiceConfig(addrs []resolver.Address) { func (cc *ClientConn) maybeApplyDefaultServiceConfig() {
if cc.sc != nil { if cc.sc != nil {
cc.applyServiceConfigAndBalancer(cc.sc, nil, addrs) cc.applyServiceConfigAndBalancer(cc.sc, nil)
return return
} }
if cc.dopts.defaultServiceConfig != nil { if cc.dopts.defaultServiceConfig != nil {
cc.applyServiceConfigAndBalancer(cc.dopts.defaultServiceConfig, &defaultConfigSelector{cc.dopts.defaultServiceConfig}, addrs) cc.applyServiceConfigAndBalancer(cc.dopts.defaultServiceConfig, &defaultConfigSelector{cc.dopts.defaultServiceConfig})
} else { } else {
cc.applyServiceConfigAndBalancer(emptyServiceConfig, &defaultConfigSelector{emptyServiceConfig}, addrs) cc.applyServiceConfigAndBalancer(emptyServiceConfig, &defaultConfigSelector{emptyServiceConfig})
} }
} }
@ -733,7 +735,7 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
// May need to apply the initial service config in case the resolver // May need to apply the initial service config in case the resolver
// doesn't support service configs, or doesn't provide a service config // doesn't support service configs, or doesn't provide a service config
// with the new addresses. // with the new addresses.
cc.maybeApplyDefaultServiceConfig(nil) cc.maybeApplyDefaultServiceConfig()
cc.balancerWrapper.resolverError(err) cc.balancerWrapper.resolverError(err)
@ -744,10 +746,10 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
var ret error var ret error
if cc.dopts.disableServiceConfig { if cc.dopts.disableServiceConfig {
channelz.Infof(logger, cc.channelzID, "ignoring service config from resolver (%v) and applying the default because service config is disabled", s.ServiceConfig) channelz.Infof(logger, cc.channelz, "ignoring service config from resolver (%v) and applying the default because service config is disabled", s.ServiceConfig)
cc.maybeApplyDefaultServiceConfig(s.Addresses) cc.maybeApplyDefaultServiceConfig()
} else if s.ServiceConfig == nil { } else if s.ServiceConfig == nil {
cc.maybeApplyDefaultServiceConfig(s.Addresses) cc.maybeApplyDefaultServiceConfig()
// TODO: do we need to apply a failing LB policy if there is no // TODO: do we need to apply a failing LB policy if there is no
// default, per the error handling design? // default, per the error handling design?
} else { } else {
@ -755,12 +757,12 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
configSelector := iresolver.GetConfigSelector(s) configSelector := iresolver.GetConfigSelector(s)
if configSelector != nil { if configSelector != nil {
if len(s.ServiceConfig.Config.(*ServiceConfig).Methods) != 0 { if len(s.ServiceConfig.Config.(*ServiceConfig).Methods) != 0 {
channelz.Infof(logger, cc.channelzID, "method configs in service config will be ignored due to presence of config selector") channelz.Infof(logger, cc.channelz, "method configs in service config will be ignored due to presence of config selector")
} }
} else { } else {
configSelector = &defaultConfigSelector{sc} configSelector = &defaultConfigSelector{sc}
} }
cc.applyServiceConfigAndBalancer(sc, configSelector, s.Addresses) cc.applyServiceConfigAndBalancer(sc, configSelector)
} else { } else {
ret = balancer.ErrBadResolverState ret = balancer.ErrBadResolverState
if cc.sc == nil { if cc.sc == nil {
@ -775,7 +777,7 @@ func (cc *ClientConn) updateResolverStateAndUnlock(s resolver.State, err error)
var balCfg serviceconfig.LoadBalancingConfig var balCfg serviceconfig.LoadBalancingConfig
if cc.sc != nil && cc.sc.lbConfig != nil { if cc.sc != nil && cc.sc.lbConfig != nil {
balCfg = cc.sc.lbConfig.cfg balCfg = cc.sc.lbConfig
} }
bw := cc.balancerWrapper bw := cc.balancerWrapper
cc.mu.Unlock() cc.mu.Unlock()
@ -806,17 +808,11 @@ func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) {
cc.csMgr.updateState(connectivity.TransientFailure) cc.csMgr.updateState(connectivity.TransientFailure)
} }
// Makes a copy of the input addresses slice and clears out the balancer // Makes a copy of the input addresses slice. Addresses are passed during
// attributes field. Addresses are passed during subconn creation and address // subconn creation and address update operations.
// update operations. In both cases, we will clear the balancer attributes by func copyAddresses(in []resolver.Address) []resolver.Address {
// calling this function, and therefore we will be able to use the Equal method
// provided by the resolver.Address type for comparison.
func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
out := make([]resolver.Address, len(in)) out := make([]resolver.Address, len(in))
for i := range in { copy(out, in)
out[i] = in[i]
out[i].BalancerAttributes = nil
}
return out return out
} }
@ -831,25 +827,23 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
ac := &addrConn{ ac := &addrConn{
state: connectivity.Idle, state: connectivity.Idle,
cc: cc, cc: cc,
addrs: copyAddressesWithoutBalancerAttributes(addrs), addrs: copyAddresses(addrs),
scopts: opts, scopts: opts,
dopts: cc.dopts, dopts: cc.dopts,
czData: new(channelzData), channelz: channelz.RegisterSubChannel(cc.channelz, ""),
resetBackoff: make(chan struct{}), resetBackoff: make(chan struct{}),
stateChan: make(chan struct{}), stateReadyChan: make(chan struct{}),
} }
ac.ctx, ac.cancel = context.WithCancel(cc.ctx) ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
// Start with our address set to the first address; this may be updated if
// we connect to different addresses.
ac.channelz.ChannelMetrics.Target.Store(&addrs[0].Addr)
var err error channelz.AddTraceEvent(logger, ac.channelz, 0, &channelz.TraceEvent{
ac.channelzID, err = channelz.RegisterSubChannel(ac, cc.channelzID, "")
if err != nil {
return nil, err
}
channelz.AddTraceEvent(logger, ac.channelzID, 0, &channelz.TraceEventDesc{
Desc: "Subchannel created", Desc: "Subchannel created",
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
Parent: &channelz.TraceEventDesc{ Parent: &channelz.TraceEvent{
Desc: fmt.Sprintf("Subchannel(id:%d) created", ac.channelzID.Int()), Desc: fmt.Sprintf("Subchannel(id:%d) created", ac.channelz.ID),
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
}, },
}) })
@ -872,38 +866,27 @@ func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) {
ac.tearDown(err) ac.tearDown(err)
} }
func (cc *ClientConn) channelzMetric() *channelz.ChannelInternalMetric {
return &channelz.ChannelInternalMetric{
State: cc.GetState(),
Target: cc.target,
CallsStarted: atomic.LoadInt64(&cc.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&cc.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&cc.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&cc.czData.lastCallStartedTime)),
}
}
// Target returns the target string of the ClientConn. // Target returns the target string of the ClientConn.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func (cc *ClientConn) Target() string { func (cc *ClientConn) Target() string {
return cc.target return cc.target
} }
// CanonicalTarget returns the canonical target string of the ClientConn.
func (cc *ClientConn) CanonicalTarget() string {
return cc.parsedTarget.String()
}
func (cc *ClientConn) incrCallsStarted() { func (cc *ClientConn) incrCallsStarted() {
atomic.AddInt64(&cc.czData.callsStarted, 1) cc.channelz.ChannelMetrics.CallsStarted.Add(1)
atomic.StoreInt64(&cc.czData.lastCallStartedTime, time.Now().UnixNano()) cc.channelz.ChannelMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
} }
func (cc *ClientConn) incrCallsSucceeded() { func (cc *ClientConn) incrCallsSucceeded() {
atomic.AddInt64(&cc.czData.callsSucceeded, 1) cc.channelz.ChannelMetrics.CallsSucceeded.Add(1)
} }
func (cc *ClientConn) incrCallsFailed() { func (cc *ClientConn) incrCallsFailed() {
atomic.AddInt64(&cc.czData.callsFailed, 1) cc.channelz.ChannelMetrics.CallsFailed.Add(1)
} }
// connect starts creating a transport. // connect starts creating a transport.
@ -925,32 +908,37 @@ func (ac *addrConn) connect() error {
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
ac.mu.Unlock()
ac.resetTransport() ac.resetTransportAndUnlock()
return nil return nil
} }
func equalAddresses(a, b []resolver.Address) bool { // equalAddressIgnoringBalAttributes returns true is a and b are considered equal.
if len(a) != len(b) { // This is different from the Equal method on the resolver.Address type which
return false // considers all fields to determine equality. Here, we only consider fields
} // that are meaningful to the subConn.
for i, v := range a { func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
if !v.Equal(b[i]) { return a.Addr == b.Addr && a.ServerName == b.ServerName &&
return false a.Attributes.Equal(b.Attributes) &&
} a.Metadata == b.Metadata
} }
return true
func equalAddressesIgnoringBalAttributes(a, b []resolver.Address) bool {
return slices.EqualFunc(a, b, func(a, b resolver.Address) bool { return equalAddressIgnoringBalAttributes(&a, &b) })
} }
// updateAddrs updates ac.addrs with the new addresses list and handles active // updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts. // connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) { func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.mu.Lock() addrs = copyAddresses(addrs)
channelz.Infof(logger, ac.channelzID, "addrConn: updateAddrs curAddr: %v, addrs: %v", pretty.ToJSON(ac.curAddr), pretty.ToJSON(addrs)) limit := len(addrs)
if limit > 5 {
limit = 5
}
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit])
addrs = copyAddressesWithoutBalancerAttributes(addrs) ac.mu.Lock()
if equalAddresses(ac.addrs, addrs) { if equalAddressesIgnoringBalAttributes(ac.addrs, addrs) {
ac.mu.Unlock() ac.mu.Unlock()
return return
} }
@ -969,7 +957,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
// Try to find the connected address. // Try to find the connected address.
for _, a := range addrs { for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a) a.ServerName = ac.cc.getServerName(a)
if a.Equal(ac.curAddr) { if equalAddressIgnoringBalAttributes(&a, &ac.curAddr) {
// We are connected to a valid address, so do nothing but // We are connected to a valid address, so do nothing but
// update the addresses. // update the addresses.
ac.mu.Unlock() ac.mu.Unlock()
@ -995,11 +983,9 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
ac.updateConnectivityState(connectivity.Idle, nil) ac.updateConnectivityState(connectivity.Idle, nil)
} }
ac.mu.Unlock()
// Since we were connecting/connected, we should start a new connection // Since we were connecting/connected, we should start a new connection
// attempt. // attempt.
go ac.resetTransport() go ac.resetTransportAndUnlock()
} }
// getServerName determines the serverName to be used in the connection // getServerName determines the serverName to be used in the connection
@ -1067,7 +1053,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method st
}) })
} }
func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector, addrs []resolver.Address) { func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector) {
if sc == nil { if sc == nil {
// should never reach here. // should never reach here.
return return
@ -1088,17 +1074,6 @@ func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSel
} else { } else {
cc.retryThrottler.Store((*retryThrottler)(nil)) cc.retryThrottler.Store((*retryThrottler)(nil))
} }
var newBalancerName string
if cc.sc == nil || (cc.sc.lbConfig == nil && cc.sc.LB == nil) {
// No service config or no LB policy specified in config.
newBalancerName = PickFirstBalancerName
} else if cc.sc.lbConfig != nil {
newBalancerName = cc.sc.lbConfig.name
} else { // cc.sc.LB != nil
newBalancerName = *cc.sc.LB
}
cc.balancerWrapper.switchTo(newBalancerName)
} }
func (cc *ClientConn) resolveNow(o resolver.ResolveNowOptions) { func (cc *ClientConn) resolveNow(o resolver.ResolveNowOptions) {
@ -1174,7 +1149,7 @@ func (cc *ClientConn) Close() error {
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add // TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from being // trace reference to the entity being deleted, and thus prevent it from being
// deleted right away. // deleted right away.
channelz.RemoveEntry(cc.channelzID) channelz.RemoveEntry(cc.channelz.ID)
return nil return nil
} }
@ -1195,19 +1170,22 @@ type addrConn struct {
// is received, transport is closed, ac has been torn down). // is received, transport is closed, ac has been torn down).
transport transport.ClientTransport // The current transport. transport transport.ClientTransport // The current transport.
// This mutex is used on the RPC path, so its usage should be minimized as
// much as possible.
// TODO: Find a lock-free way to retrieve the transport and state from the
// addrConn.
mu sync.Mutex mu sync.Mutex
curAddr resolver.Address // The current address. curAddr resolver.Address // The current address.
addrs []resolver.Address // All addresses that the resolver resolved to. addrs []resolver.Address // All addresses that the resolver resolved to.
// Use updateConnectivityState for updating addrConn's connectivity state. // Use updateConnectivityState for updating addrConn's connectivity state.
state connectivity.State state connectivity.State
stateChan chan struct{} // closed and recreated on every state change. stateReadyChan chan struct{} // closed and recreated on every READY state change.
backoffIdx int // Needs to be stateful for resetConnectBackoff. backoffIdx int // Needs to be stateful for resetConnectBackoff.
resetBackoff chan struct{} resetBackoff chan struct{}
channelzID *channelz.Identifier channelz *channelz.SubChannel
czData *channelzData
} }
// Note: this requires a lock on ac.mu. // Note: this requires a lock on ac.mu.
@ -1215,16 +1193,14 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s { if ac.state == s {
return return
} }
// When changing states, reset the state change channel.
close(ac.stateChan)
ac.stateChan = make(chan struct{})
ac.state = s ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil { if lastErr == nil {
channelz.Infof(logger, ac.channelzID, "Subchannel Connectivity change to %v", s) channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v", s)
} else { } else {
channelz.Infof(logger, ac.channelzID, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
} }
ac.acbw.updateState(s, lastErr) ac.acbw.updateState(s, ac.curAddr, lastErr)
} }
// adjustParams updates parameters used to create transports upon // adjustParams updates parameters used to create transports upon
@ -1241,8 +1217,10 @@ func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
} }
} }
func (ac *addrConn) resetTransport() { // resetTransportAndUnlock unconditionally connects the addrConn.
ac.mu.Lock() //
// ac.mu must be held by the caller, and this function will guarantee it is released.
func (ac *addrConn) resetTransportAndUnlock() {
acCtx := ac.ctx acCtx := ac.ctx
if acCtx.Err() != nil { if acCtx.Err() != nil {
ac.mu.Unlock() ac.mu.Unlock()
@ -1320,6 +1298,7 @@ func (ac *addrConn) resetTransport() {
func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error { func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, connectDeadline time.Time) error {
var firstConnErr error var firstConnErr error
for _, addr := range addrs { for _, addr := range addrs {
ac.channelz.ChannelMetrics.Target.Store(&addr.Addr)
if ctx.Err() != nil { if ctx.Err() != nil {
return errConnClosing return errConnClosing
} }
@ -1335,7 +1314,7 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
} }
ac.mu.Unlock() ac.mu.Unlock()
channelz.Infof(logger, ac.channelzID, "Subchannel picks a new address %q to connect", addr.Addr) channelz.Infof(logger, ac.channelz, "Subchannel picks a new address %q to connect", addr.Addr)
err := ac.createTransport(ctx, addr, copts, connectDeadline) err := ac.createTransport(ctx, addr, copts, connectDeadline)
if err == nil { if err == nil {
@ -1388,7 +1367,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
connectCtx, cancel := context.WithDeadline(ctx, connectDeadline) connectCtx, cancel := context.WithDeadline(ctx, connectDeadline)
defer cancel() defer cancel()
copts.ChannelzParentID = ac.channelzID copts.ChannelzParent = ac.channelz
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose) newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose)
if err != nil { if err != nil {
@ -1397,7 +1376,7 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
} }
// newTr is either nil, or closed. // newTr is either nil, or closed.
hcancel() hcancel()
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err) channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
return err return err
} }
@ -1469,7 +1448,7 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
// The health package is not imported to set health check function. // The health package is not imported to set health check function.
// //
// TODO: add a link to the health check doc in the error message. // TODO: add a link to the health check doc in the error message.
channelz.Error(logger, ac.channelzID, "Health check is requested but health check function is not set.") channelz.Error(logger, ac.channelz, "Health check is requested but health check function is not set.")
return return
} }
@ -1499,9 +1478,9 @@ func (ac *addrConn) startHealthCheck(ctx context.Context) {
err := ac.cc.dopts.healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName) err := ac.cc.dopts.healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
if err != nil { if err != nil {
if status.Code(err) == codes.Unimplemented { if status.Code(err) == codes.Unimplemented {
channelz.Error(logger, ac.channelzID, "Subchannel health check is unimplemented at server side, thus health check is disabled") channelz.Error(logger, ac.channelz, "Subchannel health check is unimplemented at server side, thus health check is disabled")
} else { } else {
channelz.Errorf(logger, ac.channelzID, "Health checking failed: %v", err) channelz.Errorf(logger, ac.channelz, "Health checking failed: %v", err)
} }
} }
}() }()
@ -1531,7 +1510,7 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport {
func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) { func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) {
for ctx.Err() == nil { for ctx.Err() == nil {
ac.mu.Lock() ac.mu.Lock()
t, state, sc := ac.transport, ac.state, ac.stateChan t, state, sc := ac.transport, ac.state, ac.stateReadyChan
ac.mu.Unlock() ac.mu.Unlock()
if state == connectivity.Ready { if state == connectivity.Ready {
return t, nil return t, nil
@ -1566,18 +1545,18 @@ func (ac *addrConn) tearDown(err error) {
ac.cancel() ac.cancel()
ac.curAddr = resolver.Address{} ac.curAddr = resolver.Address{}
channelz.AddTraceEvent(logger, ac.channelzID, 0, &channelz.TraceEventDesc{ channelz.AddTraceEvent(logger, ac.channelz, 0, &channelz.TraceEvent{
Desc: "Subchannel deleted", Desc: "Subchannel deleted",
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
Parent: &channelz.TraceEventDesc{ Parent: &channelz.TraceEvent{
Desc: fmt.Sprintf("Subchannel(id:%d) deleted", ac.channelzID.Int()), Desc: fmt.Sprintf("Subchannel(id:%d) deleted", ac.channelz.ID),
Severity: channelz.CtInfo, Severity: channelz.CtInfo,
}, },
}) })
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add // TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from // trace reference to the entity being deleted, and thus prevent it from
// being deleted right away. // being deleted right away.
channelz.RemoveEntry(ac.channelzID) channelz.RemoveEntry(ac.channelz.ID)
ac.mu.Unlock() ac.mu.Unlock()
// We have to release the lock before the call to GracefulClose/Close here // We have to release the lock before the call to GracefulClose/Close here
@ -1594,7 +1573,7 @@ func (ac *addrConn) tearDown(err error) {
} else { } else {
// Hard close the transport when the channel is entering idle or is // Hard close the transport when the channel is entering idle or is
// being shutdown. In the case where the channel is being shutdown, // being shutdown. In the case where the channel is being shutdown,
// closing of transports is also taken care of by cancelation of cc.ctx. // closing of transports is also taken care of by cancellation of cc.ctx.
// But in the case where the channel is entering idle, we need to // But in the case where the channel is entering idle, we need to
// explicitly close the transports here. Instead of distinguishing // explicitly close the transports here. Instead of distinguishing
// between these two cases, it is simpler to close the transport // between these two cases, it is simpler to close the transport
@ -1604,39 +1583,6 @@ func (ac *addrConn) tearDown(err error) {
} }
} }
func (ac *addrConn) getState() connectivity.State {
ac.mu.Lock()
defer ac.mu.Unlock()
return ac.state
}
func (ac *addrConn) ChannelzMetric() *channelz.ChannelInternalMetric {
ac.mu.Lock()
addr := ac.curAddr.Addr
ac.mu.Unlock()
return &channelz.ChannelInternalMetric{
State: ac.getState(),
Target: addr,
CallsStarted: atomic.LoadInt64(&ac.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&ac.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&ac.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&ac.czData.lastCallStartedTime)),
}
}
func (ac *addrConn) incrCallsStarted() {
atomic.AddInt64(&ac.czData.callsStarted, 1)
atomic.StoreInt64(&ac.czData.lastCallStartedTime, time.Now().UnixNano())
}
func (ac *addrConn) incrCallsSucceeded() {
atomic.AddInt64(&ac.czData.callsSucceeded, 1)
}
func (ac *addrConn) incrCallsFailed() {
atomic.AddInt64(&ac.czData.callsFailed, 1)
}
type retryThrottler struct { type retryThrottler struct {
max float64 max float64
thresh float64 thresh float64
@ -1674,12 +1620,17 @@ func (rt *retryThrottler) successfulRPC() {
} }
} }
type channelzChannel struct { func (ac *addrConn) incrCallsStarted() {
cc *ClientConn ac.channelz.ChannelMetrics.CallsStarted.Add(1)
ac.channelz.ChannelMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
} }
func (c *channelzChannel) ChannelzMetric() *channelz.ChannelInternalMetric { func (ac *addrConn) incrCallsSucceeded() {
return c.cc.channelzMetric() ac.channelz.ChannelMetrics.CallsSucceeded.Add(1)
}
func (ac *addrConn) incrCallsFailed() {
ac.channelz.ChannelMetrics.CallsFailed.Add(1)
} }
// ErrClientConnTimeout indicates that the ClientConn cannot establish the // ErrClientConnTimeout indicates that the ClientConn cannot establish the
@ -1713,22 +1664,19 @@ func (cc *ClientConn) connectionError() error {
return cc.lastConnectionError return cc.lastConnectionError
} }
// parseTargetAndFindResolver parses the user's dial target and stores the // initParsedTargetAndResolverBuilder parses the user's dial target and stores
// parsed target in `cc.parsedTarget`. // the parsed target in `cc.parsedTarget`.
// //
// The resolver to use is determined based on the scheme in the parsed target // The resolver to use is determined based on the scheme in the parsed target
// and the same is stored in `cc.resolverBuilder`. // and the same is stored in `cc.resolverBuilder`.
// //
// Doesn't grab cc.mu as this method is expected to be called only at Dial time. // Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) parseTargetAndFindResolver() error { func (cc *ClientConn) initParsedTargetAndResolverBuilder() error {
channelz.Infof(logger, cc.channelzID, "original dial target is: %q", cc.target) logger.Infof("original dial target is: %q", cc.target)
var rb resolver.Builder var rb resolver.Builder
parsedTarget, err := parseTarget(cc.target) parsedTarget, err := parseTarget(cc.target)
if err != nil { if err == nil {
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", cc.target, err)
} else {
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %#v", parsedTarget)
rb = cc.getResolver(parsedTarget.URL.Scheme) rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb != nil { if rb != nil {
cc.parsedTarget = parsedTarget cc.parsedTarget = parsedTarget
@ -1740,17 +1688,19 @@ func (cc *ClientConn) parseTargetAndFindResolver() error {
// We are here because the user's dial target did not contain a scheme or // We are here because the user's dial target did not contain a scheme or
// specified an unregistered scheme. We should fallback to the default // specified an unregistered scheme. We should fallback to the default
// scheme, except when a custom dialer is specified in which case, we should // scheme, except when a custom dialer is specified in which case, we should
// always use passthrough scheme. // always use passthrough scheme. For either case, we need to respect any overridden
defScheme := resolver.GetDefaultScheme() // global defaults set by the user.
channelz.Infof(logger, cc.channelzID, "fallback to scheme %q", defScheme) defScheme := cc.dopts.defaultScheme
if internal.UserSetDefaultScheme {
defScheme = resolver.GetDefaultScheme()
}
canonicalTarget := defScheme + ":///" + cc.target canonicalTarget := defScheme + ":///" + cc.target
parsedTarget, err = parseTarget(canonicalTarget) parsedTarget, err = parseTarget(canonicalTarget)
if err != nil { if err != nil {
channelz.Infof(logger, cc.channelzID, "dial target %q parse failed: %v", canonicalTarget, err)
return err return err
} }
channelz.Infof(logger, cc.channelzID, "parsed dial target is: %+v", parsedTarget)
rb = cc.getResolver(parsedTarget.URL.Scheme) rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb == nil { if rb == nil {
return fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.URL.Scheme) return fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.URL.Scheme)
@ -1772,6 +1722,8 @@ func parseTarget(target string) (resolver.Target, error) {
return resolver.Target{URL: *u}, nil return resolver.Target{URL: *u}, nil
} }
// encodeAuthority escapes the authority string based on valid chars defined in
// https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.
func encodeAuthority(authority string) string { func encodeAuthority(authority string) string {
const upperhex = "0123456789ABCDEF" const upperhex = "0123456789ABCDEF"
@ -1788,7 +1740,7 @@ func encodeAuthority(authority string) string {
return false return false
case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // Subdelim characters case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // Subdelim characters
return false return false
case ':', '[', ']', '@': // Authority related delimeters case ':', '[', ']', '@': // Authority related delimiters
return false return false
} }
// Everything else must be escaped. // Everything else must be escaped.
@ -1838,7 +1790,7 @@ func encodeAuthority(authority string) string {
// credentials do not match the authority configured through the dial option. // credentials do not match the authority configured through the dial option.
// //
// Doesn't grab cc.mu as this method is expected to be called only at Dial time. // Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) determineAuthority() error { func (cc *ClientConn) initAuthority() error {
dopts := cc.dopts dopts := cc.dopts
// Historically, we had two options for users to specify the serverName or // Historically, we had two options for users to specify the serverName or
// authority for a channel. One was through the transport credentials // authority for a channel. One was through the transport credentials
@ -1871,6 +1823,5 @@ func (cc *ClientConn) determineAuthority() error {
} else { } else {
cc.authority = encodeAuthority(endpoint) cc.authority = encodeAuthority(endpoint)
} }
channelz.Infof(logger, cc.channelzID, "Channel authority set to %q", cc.authority)
return nil return nil
} }

View File

@ -21,18 +21,73 @@ package grpc
import ( import (
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto" _ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto"
"google.golang.org/grpc/mem"
) )
// baseCodec contains the functionality of both Codec and encoding.Codec, but // baseCodec captures the new encoding.CodecV2 interface without the Name
// omits the name/string, which vary between the two and are not needed for // function, allowing it to be implemented by older Codec and encoding.Codec
// anything besides the registry in the encoding package. // implementations. The omitted Name function is only needed for the register in
// the encoding package and is not part of the core functionality.
type baseCodec interface { type baseCodec interface {
Marshal(v any) ([]byte, error) Marshal(v any) (mem.BufferSlice, error)
Unmarshal(data []byte, v any) error Unmarshal(data mem.BufferSlice, v any) error
} }
var _ baseCodec = Codec(nil) // getCodec returns an encoding.CodecV2 for the codec of the given name (if
var _ baseCodec = encoding.Codec(nil) // registered). Initially checks the V2 registry with encoding.GetCodecV2 and
// returns the V2 codec if it is registered. Otherwise, it checks the V1 registry
// with encoding.GetCodec and if it is registered wraps it with newCodecV1Bridge
// to turn it into an encoding.CodecV2. Returns nil otherwise.
func getCodec(name string) encoding.CodecV2 {
if codecV1 := encoding.GetCodec(name); codecV1 != nil {
return newCodecV1Bridge(codecV1)
}
return encoding.GetCodecV2(name)
}
func newCodecV0Bridge(c Codec) baseCodec {
return codecV0Bridge{codec: c}
}
func newCodecV1Bridge(c encoding.Codec) encoding.CodecV2 {
return codecV1Bridge{
codecV0Bridge: codecV0Bridge{codec: c},
name: c.Name(),
}
}
var _ baseCodec = codecV0Bridge{}
type codecV0Bridge struct {
codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}
}
func (c codecV0Bridge) Marshal(v any) (mem.BufferSlice, error) {
data, err := c.codec.Marshal(v)
if err != nil {
return nil, err
}
return mem.BufferSlice{mem.NewBuffer(&data, nil)}, nil
}
func (c codecV0Bridge) Unmarshal(data mem.BufferSlice, v any) (err error) {
return c.codec.Unmarshal(data.Materialize(), v)
}
var _ encoding.CodecV2 = codecV1Bridge{}
type codecV1Bridge struct {
codecV0Bridge
name string
}
func (c codecV1Bridge) Name() string {
return c.name
}
// Codec defines the interface gRPC uses to encode and decode messages. // Codec defines the interface gRPC uses to encode and decode messages.
// Note that implementations of this interface must be thread safe; // Note that implementations of this interface must be thread safe;

View File

@ -1,17 +0,0 @@
#!/usr/bin/env bash
# This script serves as an example to demonstrate how to generate the gRPC-Go
# interface and the related messages from .proto file.
#
# It assumes the installation of i) Google proto buffer compiler at
# https://github.com/google/protobuf (after v2.6.1) and ii) the Go codegen
# plugin at https://github.com/golang/protobuf (after 2015-02-20). If you have
# not, please install them first.
#
# We recommend running this script at $GOPATH/src.
#
# If this is not what you need, feel free to make your own scripts. Again, this
# script is for demonstration purpose.
#
proto=$1
protoc --go_out=plugins=grpc:. $proto

View File

@ -235,7 +235,7 @@ func (c *Code) UnmarshalJSON(b []byte) error {
if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil { if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil {
if ci >= _maxCode { if ci >= _maxCode {
return fmt.Errorf("invalid code: %q", ci) return fmt.Errorf("invalid code: %d", ci)
} }
*c = Code(ci) *c = Code(ci)

View File

@ -28,9 +28,9 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/attributes" "google.golang.org/grpc/attributes"
icredentials "google.golang.org/grpc/internal/credentials" icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/protobuf/proto"
) )
// PerRPCCredentials defines the common interface for the credentials which need to // PerRPCCredentials defines the common interface for the credentials which need to
@ -237,7 +237,7 @@ func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
} }
// CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one. // CheckSecurityLevel checks if a connection's security level is greater than or equal to the specified one.
// It returns success if 1) the condition is satisified or 2) AuthInfo struct does not implement GetCommonAuthInfo() method // It returns success if 1) the condition is satisfied or 2) AuthInfo struct does not implement GetCommonAuthInfo() method
// or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility. // or 3) CommonAuthInfo.SecurityLevel has an invalid zero value. For 2) and 3), it is for the purpose of backward-compatibility.
// //
// This API is experimental. // This API is experimental.

View File

@ -27,9 +27,13 @@ import (
"net/url" "net/url"
"os" "os"
"google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials" credinternal "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/envconfig"
) )
var logger = grpclog.Component("credentials")
// TLSInfo contains the auth information for a TLS authenticated connection. // TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface. // It implements the AuthInfo interface.
type TLSInfo struct { type TLSInfo struct {
@ -112,6 +116,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
conn.Close() conn.Close()
return nil, nil, ctx.Err() return nil, nil, ctx.Err()
} }
// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
}
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
}
tlsInfo := TLSInfo{ tlsInfo := TLSInfo{
State: conn.ConnectionState(), State: conn.ConnectionState(),
CommonAuthInfo: CommonAuthInfo{ CommonAuthInfo: CommonAuthInfo{
@ -131,8 +151,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
conn.Close() conn.Close()
return nil, nil, err return nil, nil, err
} }
cs := conn.ConnectionState()
// The negotiated application protocol can be empty only if the client doesn't
// support ALPN. In such cases, we can close the connection since ALPN is required
// for using HTTP/2 over TLS.
if cs.NegotiatedProtocol == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
} else if logger.V(2) {
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
}
}
tlsInfo := TLSInfo{ tlsInfo := TLSInfo{
State: conn.ConnectionState(), State: cs,
CommonAuthInfo: CommonAuthInfo{ CommonAuthInfo: CommonAuthInfo{
SecurityLevel: PrivacyAndIntegrity, SecurityLevel: PrivacyAndIntegrity,
}, },

View File

@ -21,6 +21,7 @@ package grpc
import ( import (
"context" "context"
"net" "net"
"net/url"
"time" "time"
"google.golang.org/grpc/backoff" "google.golang.org/grpc/backoff"
@ -32,10 +33,16 @@ import (
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
) )
const (
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#limits-on-retries-and-hedges
defaultMaxCallAttempts = 5
)
func init() { func init() {
internal.AddGlobalDialOptions = func(opt ...DialOption) { internal.AddGlobalDialOptions = func(opt ...DialOption) {
globalDialOptions = append(globalDialOptions, opt...) globalDialOptions = append(globalDialOptions, opt...)
@ -43,10 +50,18 @@ func init() {
internal.ClearGlobalDialOptions = func() { internal.ClearGlobalDialOptions = func() {
globalDialOptions = nil globalDialOptions = nil
} }
internal.AddGlobalPerTargetDialOptions = func(opt any) {
if ptdo, ok := opt.(perTargetDialOption); ok {
globalPerTargetDialOptions = append(globalPerTargetDialOptions, ptdo)
}
}
internal.ClearGlobalPerTargetDialOptions = func() {
globalPerTargetDialOptions = nil
}
internal.WithBinaryLogger = withBinaryLogger internal.WithBinaryLogger = withBinaryLogger
internal.JoinDialOptions = newJoinDialOption internal.JoinDialOptions = newJoinDialOption
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
internal.WithRecvBufferPool = withRecvBufferPool internal.WithBufferPool = withBufferPool
} }
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
@ -68,7 +83,7 @@ type dialOptions struct {
binaryLogger binarylog.Logger binaryLogger binarylog.Logger
copts transport.ConnectOptions copts transport.ConnectOptions
callOptions []CallOption callOptions []CallOption
channelzParentID *channelz.Identifier channelzParent channelz.Identifier
disableServiceConfig bool disableServiceConfig bool
disableRetry bool disableRetry bool
disableHealthCheck bool disableHealthCheck bool
@ -78,7 +93,8 @@ type dialOptions struct {
defaultServiceConfigRawJSON *string defaultServiceConfigRawJSON *string
resolvers []resolver.Builder resolvers []resolver.Builder
idleTimeout time.Duration idleTimeout time.Duration
recvBufferPool SharedBufferPool defaultScheme string
maxCallAttempts int
} }
// DialOption configures how we set up the connection. // DialOption configures how we set up the connection.
@ -88,6 +104,19 @@ type DialOption interface {
var globalDialOptions []DialOption var globalDialOptions []DialOption
// perTargetDialOption takes a parsed target and returns a dial option to apply.
//
// This gets called after NewClient() parses the target, and allows per target
// configuration set through a returned DialOption. The DialOption will not take
// effect if specifies a resolver builder, as that Dial Option is factored in
// while parsing target.
type perTargetDialOption interface {
// DialOption returns a Dial Option to apply.
DialOptionForTarget(parsedTarget url.URL) DialOption
}
var globalPerTargetDialOptions []perTargetDialOption
// EmptyDialOption does not alter the dial configuration. It can be embedded in // EmptyDialOption does not alter the dial configuration. It can be embedded in
// another structure to build custom dial options. // another structure to build custom dial options.
// //
@ -154,9 +183,7 @@ func WithSharedWriteBuffer(val bool) DialOption {
} }
// WithWriteBufferSize determines how much data can be batched before doing a // WithWriteBufferSize determines how much data can be batched before doing a
// write on the wire. The corresponding memory allocation for this buffer will // write on the wire. The default value for this buffer is 32KB.
// be twice the size to keep syscalls low. The default value for this buffer is
// 32KB.
// //
// Zero or negative values will disable the write buffer such that each write // Zero or negative values will disable the write buffer such that each write
// will be on underlying connection. Note: A Send call may not directly // will be on underlying connection. Note: A Send call may not directly
@ -301,6 +328,9 @@ func withBackoff(bs internalbackoff.Strategy) DialOption {
// //
// Use of this feature is not recommended. For more information, please see: // Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md // https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
//
// Deprecated: this DialOption is not supported by NewClient.
// Will be supported throughout 1.x.
func WithBlock() DialOption { func WithBlock() DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.block = true o.block = true
@ -315,10 +345,8 @@ func WithBlock() DialOption {
// Use of this feature is not recommended. For more information, please see: // Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md // https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
// //
// # Experimental // Deprecated: this DialOption is not supported by NewClient.
// // Will be supported throughout 1.x.
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WithReturnConnectionError() DialOption { func WithReturnConnectionError() DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.block = true o.block = true
@ -388,8 +416,8 @@ func WithCredentialsBundle(b credentials.Bundle) DialOption {
// WithTimeout returns a DialOption that configures a timeout for dialing a // WithTimeout returns a DialOption that configures a timeout for dialing a
// ClientConn initially. This is valid if and only if WithBlock() is present. // ClientConn initially. This is valid if and only if WithBlock() is present.
// //
// Deprecated: use DialContext instead of Dial and context.WithTimeout // Deprecated: this DialOption is not supported by NewClient.
// instead. Will be supported throughout 1.x. // Will be supported throughout 1.x.
func WithTimeout(d time.Duration) DialOption { func WithTimeout(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.timeout = d o.timeout = d
@ -471,9 +499,8 @@ func withBinaryLogger(bl binarylog.Logger) DialOption {
// Use of this feature is not recommended. For more information, please see: // Use of this feature is not recommended. For more information, please see:
// https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md // https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
// //
// # Experimental // Deprecated: this DialOption is not supported by NewClient.
// // This API may be changed or removed in a
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
func FailOnNonTempDialError(f bool) DialOption { func FailOnNonTempDialError(f bool) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
@ -555,9 +582,9 @@ func WithAuthority(a string) DialOption {
// //
// Notice: This API is EXPERIMENTAL and may be changed or removed in a // Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
func WithChannelzParentID(id *channelz.Identifier) DialOption { func WithChannelzParentID(c channelz.Identifier) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.channelzParentID = id o.channelzParent = c
}) })
} }
@ -602,12 +629,22 @@ func WithDisableRetry() DialOption {
}) })
} }
// MaxHeaderListSizeDialOption is a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept.
type MaxHeaderListSizeDialOption struct {
MaxHeaderListSize uint32
}
func (o MaxHeaderListSizeDialOption) apply(do *dialOptions) {
do.copts.MaxHeaderListSize = &o.MaxHeaderListSize
}
// WithMaxHeaderListSize returns a DialOption that specifies the maximum // WithMaxHeaderListSize returns a DialOption that specifies the maximum
// (uncompressed) size of header list that the client is prepared to accept. // (uncompressed) size of header list that the client is prepared to accept.
func WithMaxHeaderListSize(s uint32) DialOption { func WithMaxHeaderListSize(s uint32) DialOption {
return newFuncDialOption(func(o *dialOptions) { return MaxHeaderListSizeDialOption{
o.copts.MaxHeaderListSize = &s MaxHeaderListSize: s,
}) }
} }
// WithDisableHealthCheck disables the LB channel health checking for all // WithDisableHealthCheck disables the LB channel health checking for all
@ -640,15 +677,17 @@ func defaultDialOptions() dialOptions {
WriteBufferSize: defaultWriteBufSize, WriteBufferSize: defaultWriteBufSize,
UseProxy: true, UseProxy: true,
UserAgent: grpcUA, UserAgent: grpcUA,
BufferPool: mem.DefaultBufferPool(),
}, },
bs: internalbackoff.DefaultExponential, bs: internalbackoff.DefaultExponential,
healthCheckFunc: internal.HealthCheckFunc, healthCheckFunc: internal.HealthCheckFunc,
idleTimeout: 30 * time.Minute, idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{}, defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts,
} }
} }
// withGetMinConnectDeadline specifies the function that clientconn uses to // withMinConnectDeadline specifies the function that clientconn uses to
// get minConnectDeadline. This can be used to make connection attempts happen // get minConnectDeadline. This can be used to make connection attempts happen
// faster/slower. // faster/slower.
// //
@ -659,6 +698,14 @@ func withMinConnectDeadline(f func() time.Duration) DialOption {
}) })
} }
// withDefaultScheme is used to allow Dial to use "passthrough" as the default
// name resolver, while NewClient uses "dns" otherwise.
func withDefaultScheme(s string) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.defaultScheme = s
})
}
// WithResolvers allows a list of resolver implementations to be registered // WithResolvers allows a list of resolver implementations to be registered
// locally with the ClientConn without needing to be globally registered via // locally with the ClientConn without needing to be globally registered via
// resolver.Register. They will be matched against the scheme used for the // resolver.Register. They will be matched against the scheme used for the
@ -694,25 +741,25 @@ func WithIdleTimeout(d time.Duration) DialOption {
}) })
} }
// WithRecvBufferPool returns a DialOption that configures the ClientConn // WithMaxCallAttempts returns a DialOption that configures the maximum number
// to use the provided shared buffer pool for parsing incoming messages. Depending // of attempts per call (including retries and hedging) using the channel.
// on the application's workload, this could result in reduced memory allocation. // Service owners may specify a higher value for these parameters, but higher
// values will be treated as equal to the maximum value by the client
// implementation. This mitigates security concerns related to the service
// config being transferred to the client via DNS.
// //
// If you are unsure about how to implement a memory pool but want to utilize one, // A value of 5 will be used if this dial option is not set or n < 2.
// begin with grpc.NewSharedBufferPool. func WithMaxCallAttempts(n int) DialOption {
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: WithStatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// v1.60.0 or later.
func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return withRecvBufferPool(bufferPool)
}
func withRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.recvBufferPool = bufferPool if n < 2 {
n = defaultMaxCallAttempts
}
o.maxCallAttempts = n
})
}
func withBufferPool(bufferPool mem.BufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.BufferPool = bufferPool
}) })
} }

View File

@ -16,7 +16,7 @@
* *
*/ */
//go:generate ./regenerate.sh //go:generate ./scripts/regenerate.sh
/* /*
Package grpc implements an RPC system called gRPC. Package grpc implements an RPC system called gRPC.

View File

@ -94,7 +94,7 @@ type Codec interface {
Name() string Name() string
} }
var registeredCodecs = make(map[string]Codec) var registeredCodecs = make(map[string]any)
// RegisterCodec registers the provided Codec for use with all gRPC clients and // RegisterCodec registers the provided Codec for use with all gRPC clients and
// servers. // servers.
@ -126,5 +126,6 @@ func RegisterCodec(codec Codec) {
// //
// The content-subtype is expected to be lowercase. // The content-subtype is expected to be lowercase.
func GetCodec(contentSubtype string) Codec { func GetCodec(contentSubtype string) Codec {
return registeredCodecs[contentSubtype] c, _ := registeredCodecs[contentSubtype].(Codec)
return c
} }

81
vendor/google.golang.org/grpc/encoding/encoding_v2.go generated vendored Normal file
View File

@ -0,0 +1,81 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package encoding
import (
"strings"
"google.golang.org/grpc/mem"
)
// CodecV2 defines the interface gRPC uses to encode and decode messages. Note
// that implementations of this interface must be thread safe; a CodecV2's
// methods can be called from concurrent goroutines.
type CodecV2 interface {
// Marshal returns the wire format of v. The buffers in the returned
// [mem.BufferSlice] must have at least one reference each, which will be freed
// by gRPC when they are no longer needed.
Marshal(v any) (out mem.BufferSlice, err error)
// Unmarshal parses the wire format into v. Note that data will be freed as soon
// as this function returns. If the codec wishes to guarantee access to the data
// after this function, it must take its own reference that it frees when it is
// no longer needed.
Unmarshal(data mem.BufferSlice, v any) error
// Name returns the name of the Codec implementation. The returned string
// will be used as part of content type in transmission. The result must be
// static; the result cannot change between calls.
Name() string
}
// RegisterCodecV2 registers the provided CodecV2 for use with all gRPC clients and
// servers.
//
// The CodecV2 will be stored and looked up by result of its Name() method, which
// should match the content-subtype of the encoding handled by the CodecV2. This
// is case-insensitive, and is stored and looked up as lowercase. If the
// result of calling Name() is an empty string, RegisterCodecV2 will panic. See
// Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If both a Codec and CodecV2 are registered with the same name, the CodecV2
// will be used.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Codecs are
// registered with the same name, the one registered last will take effect.
func RegisterCodecV2(codec CodecV2) {
if codec == nil {
panic("cannot register a nil CodecV2")
}
if codec.Name() == "" {
panic("cannot register CodecV2 with empty string result for Name()")
}
contentSubtype := strings.ToLower(codec.Name())
registeredCodecs[contentSubtype] = codec
}
// GetCodecV2 gets a registered CodecV2 by content-subtype, or nil if no CodecV2 is
// registered for the content-subtype.
//
// The content-subtype is expected to be lowercase.
func GetCodecV2(contentSubtype string) CodecV2 {
c, _ := registeredCodecs[contentSubtype].(CodecV2)
return c
}

View File

@ -1,6 +1,6 @@
/* /*
* *
* Copyright 2018 gRPC authors. * Copyright 2024 gRPC authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import (
"fmt" "fmt"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt" "google.golang.org/protobuf/protoadapt"
) )
@ -32,28 +33,51 @@ import (
const Name = "proto" const Name = "proto"
func init() { func init() {
encoding.RegisterCodec(codec{}) encoding.RegisterCodecV2(&codecV2{})
} }
// codec is a Codec implementation with protobuf. It is the default codec for gRPC. // codec is a CodecV2 implementation with protobuf. It is the default codec for
type codec struct{} // gRPC.
type codecV2 struct{}
func (codec) Marshal(v any) ([]byte, error) { func (c *codecV2) Marshal(v any) (data mem.BufferSlice, err error) {
vv := messageV2Of(v) vv := messageV2Of(v)
if vv == nil { if vv == nil {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) return nil, fmt.Errorf("proto: failed to marshal, message is %T, want proto.Message", v)
} }
return proto.Marshal(vv) size := proto.Size(vv)
if mem.IsBelowBufferPoolingThreshold(size) {
buf, err := proto.Marshal(vv)
if err != nil {
return nil, err
}
data = append(data, mem.SliceBuffer(buf))
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(size)
if _, err := (proto.MarshalOptions{}).MarshalAppend((*buf)[:0], vv); err != nil {
pool.Put(buf)
return nil, err
}
data = append(data, mem.NewBuffer(buf, pool))
}
return data, nil
} }
func (codec) Unmarshal(data []byte, v any) error { func (c *codecV2) Unmarshal(data mem.BufferSlice, v any) (err error) {
vv := messageV2Of(v) vv := messageV2Of(v)
if vv == nil { if vv == nil {
return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v)
} }
return proto.Unmarshal(data, vv) buf := data.MaterializeToBuffer(mem.DefaultBufferPool())
defer buf.Free()
// TODO: Upgrade proto.Unmarshal to support mem.BufferSlice. Right now, it's not
// really possible without a major overhaul of the proto package, but the
// vtprotobuf library may be able to support this.
return proto.Unmarshal(buf.ReadOnlyData(), vv)
} }
func messageV2Of(v any) proto.Message { func messageV2Of(v any) proto.Message {
@ -67,6 +91,6 @@ func messageV2Of(v any) proto.Message {
return nil return nil
} }
func (codec) Name() string { func (c *codecV2) Name() string {
return Name return Name
} }

View File

@ -0,0 +1,269 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package stats
import (
"maps"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
)
func init() {
internal.SnapshotMetricRegistryForTesting = snapshotMetricsRegistryForTesting
}
var logger = grpclog.Component("metrics-registry")
// DefaultMetrics are the default metrics registered through global metrics
// registry. This is written to at initialization time only, and is read only
// after initialization.
var DefaultMetrics = NewMetrics()
// MetricDescriptor is the data for a registered metric.
type MetricDescriptor struct {
// The name of this metric. This name must be unique across the whole binary
// (including any per call metrics). See
// https://github.com/grpc/proposal/blob/master/A79-non-per-call-metrics-architecture.md#metric-instrument-naming-conventions
// for metric naming conventions.
Name Metric
// The description of this metric.
Description string
// The unit (e.g. entries, seconds) of this metric.
Unit string
// The required label keys for this metric. These are intended to
// metrics emitted from a stats handler.
Labels []string
// The optional label keys for this metric. These are intended to attached
// to metrics emitted from a stats handler if configured.
OptionalLabels []string
// Whether this metric is on by default.
Default bool
// The type of metric. This is set by the metric registry, and not intended
// to be set by a component registering a metric.
Type MetricType
// Bounds are the bounds of this metric. This only applies to histogram
// metrics. If unset or set with length 0, stats handlers will fall back to
// default bounds.
Bounds []float64
}
// MetricType is the type of metric.
type MetricType int
// Type of metric supported by this instrument registry.
const (
MetricTypeIntCount MetricType = iota
MetricTypeFloatCount
MetricTypeIntHisto
MetricTypeFloatHisto
MetricTypeIntGauge
)
// Int64CountHandle is a typed handle for a int count metric. This handle
// is passed at the recording point in order to know which metric to record
// on.
type Int64CountHandle MetricDescriptor
// Descriptor returns the int64 count handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64CountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 count value on the metrics recorder provided.
func (h *Int64CountHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Count(h, incr, labels...)
}
// Float64CountHandle is a typed handle for a float count metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Float64CountHandle MetricDescriptor
// Descriptor returns the float64 count handle typecast to a pointer to a
// MetricDescriptor.
func (h *Float64CountHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the float64 count value on the metrics recorder provided.
func (h *Float64CountHandle) Record(recorder MetricsRecorder, incr float64, labels ...string) {
recorder.RecordFloat64Count(h, incr, labels...)
}
// Int64HistoHandle is a typed handle for an int histogram metric. This handle
// is passed at the recording point in order to know which metric to record on.
type Int64HistoHandle MetricDescriptor
// Descriptor returns the int64 histo handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64HistoHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 histo value on the metrics recorder provided.
func (h *Int64HistoHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Histo(h, incr, labels...)
}
// Float64HistoHandle is a typed handle for a float histogram metric. This
// handle is passed at the recording point in order to know which metric to
// record on.
type Float64HistoHandle MetricDescriptor
// Descriptor returns the float64 histo handle typecast to a pointer to a
// MetricDescriptor.
func (h *Float64HistoHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the float64 histo value on the metrics recorder provided.
func (h *Float64HistoHandle) Record(recorder MetricsRecorder, incr float64, labels ...string) {
recorder.RecordFloat64Histo(h, incr, labels...)
}
// Int64GaugeHandle is a typed handle for an int gauge metric. This handle is
// passed at the recording point in order to know which metric to record on.
type Int64GaugeHandle MetricDescriptor
// Descriptor returns the int64 gauge handle typecast to a pointer to a
// MetricDescriptor.
func (h *Int64GaugeHandle) Descriptor() *MetricDescriptor {
return (*MetricDescriptor)(h)
}
// Record records the int64 histo value on the metrics recorder provided.
func (h *Int64GaugeHandle) Record(recorder MetricsRecorder, incr int64, labels ...string) {
recorder.RecordInt64Gauge(h, incr, labels...)
}
// registeredMetrics are the registered metric descriptor names.
var registeredMetrics = make(map[Metric]bool)
// metricsRegistry contains all of the registered metrics.
//
// This is written to only at init time, and read only after that.
var metricsRegistry = make(map[Metric]*MetricDescriptor)
// DescriptorForMetric returns the MetricDescriptor from the global registry.
//
// Returns nil if MetricDescriptor not present.
func DescriptorForMetric(metric Metric) *MetricDescriptor {
return metricsRegistry[metric]
}
func registerMetric(name Metric, def bool) {
if registeredMetrics[name] {
logger.Fatalf("metric %v already registered", name)
}
registeredMetrics[name] = true
if def {
DefaultMetrics = DefaultMetrics.Add(name)
}
}
// RegisterInt64Count registers the metric description onto the global registry.
// It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Count(descriptor MetricDescriptor) *Int64CountHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64CountHandle)(descPtr)
}
// RegisterFloat64Count registers the metric description onto the global
// registry. It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterFloat64Count(descriptor MetricDescriptor) *Float64CountHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeFloatCount
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Float64CountHandle)(descPtr)
}
// RegisterInt64Histo registers the metric description onto the global registry.
// It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Histo(descriptor MetricDescriptor) *Int64HistoHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntHisto
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64HistoHandle)(descPtr)
}
// RegisterFloat64Histo registers the metric description onto the global
// registry. It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterFloat64Histo(descriptor MetricDescriptor) *Float64HistoHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeFloatHisto
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Float64HistoHandle)(descPtr)
}
// RegisterInt64Gauge registers the metric description onto the global registry.
// It returns a typed handle to use to recording data.
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple metrics are
// registered with the same name, this function will panic.
func RegisterInt64Gauge(descriptor MetricDescriptor) *Int64GaugeHandle {
registerMetric(descriptor.Name, descriptor.Default)
descriptor.Type = MetricTypeIntGauge
descPtr := &descriptor
metricsRegistry[descriptor.Name] = descPtr
return (*Int64GaugeHandle)(descPtr)
}
// snapshotMetricsRegistryForTesting snapshots the global data of the metrics
// registry. Returns a cleanup function that sets the metrics registry to its
// original state.
func snapshotMetricsRegistryForTesting() func() {
oldDefaultMetrics := DefaultMetrics
oldRegisteredMetrics := registeredMetrics
oldMetricsRegistry := metricsRegistry
registeredMetrics = make(map[Metric]bool)
metricsRegistry = make(map[Metric]*MetricDescriptor)
maps.Copy(registeredMetrics, registeredMetrics)
maps.Copy(metricsRegistry, metricsRegistry)
return func() {
DefaultMetrics = oldDefaultMetrics
registeredMetrics = oldRegisteredMetrics
metricsRegistry = oldMetricsRegistry
}
}

View File

@ -0,0 +1,114 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package stats contains experimental metrics/stats API's.
package stats
import "maps"
// MetricsRecorder records on metrics derived from metric registry.
type MetricsRecorder interface {
// RecordInt64Count records the measurement alongside labels on the int
// count associated with the provided handle.
RecordInt64Count(handle *Int64CountHandle, incr int64, labels ...string)
// RecordFloat64Count records the measurement alongside labels on the float
// count associated with the provided handle.
RecordFloat64Count(handle *Float64CountHandle, incr float64, labels ...string)
// RecordInt64Histo records the measurement alongside labels on the int
// histo associated with the provided handle.
RecordInt64Histo(handle *Int64HistoHandle, incr int64, labels ...string)
// RecordFloat64Histo records the measurement alongside labels on the float
// histo associated with the provided handle.
RecordFloat64Histo(handle *Float64HistoHandle, incr float64, labels ...string)
// RecordInt64Gauge records the measurement alongside labels on the int
// gauge associated with the provided handle.
RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string)
}
// Metric is an identifier for a metric.
type Metric string
// Metrics is a set of metrics to record. Once created, Metrics is immutable,
// however Add and Remove can make copies with specific metrics added or
// removed, respectively.
//
// Do not construct directly; use NewMetrics instead.
type Metrics struct {
// metrics are the set of metrics to initialize.
metrics map[Metric]bool
}
// NewMetrics returns a Metrics containing Metrics.
func NewMetrics(metrics ...Metric) *Metrics {
newMetrics := make(map[Metric]bool)
for _, metric := range metrics {
newMetrics[metric] = true
}
return &Metrics{
metrics: newMetrics,
}
}
// Metrics returns the metrics set. The returned map is read-only and must not
// be modified.
func (m *Metrics) Metrics() map[Metric]bool {
return m.metrics
}
// Add adds the metrics to the metrics set and returns a new copy with the
// additional metrics.
func (m *Metrics) Add(metrics ...Metric) *Metrics {
newMetrics := make(map[Metric]bool)
for metric := range m.metrics {
newMetrics[metric] = true
}
for _, metric := range metrics {
newMetrics[metric] = true
}
return &Metrics{
metrics: newMetrics,
}
}
// Join joins the metrics passed in with the metrics set, and returns a new copy
// with the merged metrics.
func (m *Metrics) Join(metrics *Metrics) *Metrics {
newMetrics := make(map[Metric]bool)
maps.Copy(newMetrics, m.metrics)
maps.Copy(newMetrics, metrics.metrics)
return &Metrics{
metrics: newMetrics,
}
}
// Remove removes the metrics from the metrics set and returns a new copy with
// the metrics removed.
func (m *Metrics) Remove(metrics ...Metric) *Metrics {
newMetrics := make(map[Metric]bool)
for metric := range m.metrics {
newMetrics[metric] = true
}
for _, metric := range metrics {
delete(newMetrics, metric)
}
return &Metrics{
metrics: newMetrics,
}
}

View File

@ -20,8 +20,6 @@ package grpclog
import ( import (
"fmt" "fmt"
"google.golang.org/grpc/internal/grpclog"
) )
// componentData records the settings for a component. // componentData records the settings for a component.
@ -33,22 +31,22 @@ var cache = map[string]*componentData{}
func (c *componentData) InfoDepth(depth int, args ...any) { func (c *componentData) InfoDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.InfoDepth(depth+1, args...) InfoDepth(depth+1, args...)
} }
func (c *componentData) WarningDepth(depth int, args ...any) { func (c *componentData) WarningDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.WarningDepth(depth+1, args...) WarningDepth(depth+1, args...)
} }
func (c *componentData) ErrorDepth(depth int, args ...any) { func (c *componentData) ErrorDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.ErrorDepth(depth+1, args...) ErrorDepth(depth+1, args...)
} }
func (c *componentData) FatalDepth(depth int, args ...any) { func (c *componentData) FatalDepth(depth int, args ...any) {
args = append([]any{"[" + string(c.name) + "]"}, args...) args = append([]any{"[" + string(c.name) + "]"}, args...)
grpclog.FatalDepth(depth+1, args...) FatalDepth(depth+1, args...)
} }
func (c *componentData) Info(args ...any) { func (c *componentData) Info(args ...any) {

View File

@ -18,18 +18,15 @@
// Package grpclog defines logging for grpc. // Package grpclog defines logging for grpc.
// //
// All logs in transport and grpclb packages only go to verbose level 2. // In the default logger, severity level can be set by environment variable
// All logs in other packages in grpc are logged in spite of the verbosity level. // GRPC_GO_LOG_SEVERITY_LEVEL, verbosity level can be set by
// // GRPC_GO_LOG_VERBOSITY_LEVEL.
// In the default logger, package grpclog
// severity level can be set by environment variable GRPC_GO_LOG_SEVERITY_LEVEL,
// verbosity level can be set by GRPC_GO_LOG_VERBOSITY_LEVEL.
package grpclog // import "google.golang.org/grpc/grpclog"
import ( import (
"os" "os"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/grpclog/internal"
) )
func init() { func init() {
@ -38,58 +35,58 @@ func init() {
// V reports whether verbosity level l is at least the requested verbose level. // V reports whether verbosity level l is at least the requested verbose level.
func V(l int) bool { func V(l int) bool {
return grpclog.Logger.V(l) return internal.LoggerV2Impl.V(l)
} }
// Info logs to the INFO log. // Info logs to the INFO log.
func Info(args ...any) { func Info(args ...any) {
grpclog.Logger.Info(args...) internal.LoggerV2Impl.Info(args...)
} }
// Infof logs to the INFO log. Arguments are handled in the manner of fmt.Printf. // Infof logs to the INFO log. Arguments are handled in the manner of fmt.Printf.
func Infof(format string, args ...any) { func Infof(format string, args ...any) {
grpclog.Logger.Infof(format, args...) internal.LoggerV2Impl.Infof(format, args...)
} }
// Infoln logs to the INFO log. Arguments are handled in the manner of fmt.Println. // Infoln logs to the INFO log. Arguments are handled in the manner of fmt.Println.
func Infoln(args ...any) { func Infoln(args ...any) {
grpclog.Logger.Infoln(args...) internal.LoggerV2Impl.Infoln(args...)
} }
// Warning logs to the WARNING log. // Warning logs to the WARNING log.
func Warning(args ...any) { func Warning(args ...any) {
grpclog.Logger.Warning(args...) internal.LoggerV2Impl.Warning(args...)
} }
// Warningf logs to the WARNING log. Arguments are handled in the manner of fmt.Printf. // Warningf logs to the WARNING log. Arguments are handled in the manner of fmt.Printf.
func Warningf(format string, args ...any) { func Warningf(format string, args ...any) {
grpclog.Logger.Warningf(format, args...) internal.LoggerV2Impl.Warningf(format, args...)
} }
// Warningln logs to the WARNING log. Arguments are handled in the manner of fmt.Println. // Warningln logs to the WARNING log. Arguments are handled in the manner of fmt.Println.
func Warningln(args ...any) { func Warningln(args ...any) {
grpclog.Logger.Warningln(args...) internal.LoggerV2Impl.Warningln(args...)
} }
// Error logs to the ERROR log. // Error logs to the ERROR log.
func Error(args ...any) { func Error(args ...any) {
grpclog.Logger.Error(args...) internal.LoggerV2Impl.Error(args...)
} }
// Errorf logs to the ERROR log. Arguments are handled in the manner of fmt.Printf. // Errorf logs to the ERROR log. Arguments are handled in the manner of fmt.Printf.
func Errorf(format string, args ...any) { func Errorf(format string, args ...any) {
grpclog.Logger.Errorf(format, args...) internal.LoggerV2Impl.Errorf(format, args...)
} }
// Errorln logs to the ERROR log. Arguments are handled in the manner of fmt.Println. // Errorln logs to the ERROR log. Arguments are handled in the manner of fmt.Println.
func Errorln(args ...any) { func Errorln(args ...any) {
grpclog.Logger.Errorln(args...) internal.LoggerV2Impl.Errorln(args...)
} }
// Fatal logs to the FATAL log. Arguments are handled in the manner of fmt.Print. // Fatal logs to the FATAL log. Arguments are handled in the manner of fmt.Print.
// It calls os.Exit() with exit code 1. // It calls os.Exit() with exit code 1.
func Fatal(args ...any) { func Fatal(args ...any) {
grpclog.Logger.Fatal(args...) internal.LoggerV2Impl.Fatal(args...)
// Make sure fatal logs will exit. // Make sure fatal logs will exit.
os.Exit(1) os.Exit(1)
} }
@ -97,15 +94,15 @@ func Fatal(args ...any) {
// Fatalf logs to the FATAL log. Arguments are handled in the manner of fmt.Printf. // Fatalf logs to the FATAL log. Arguments are handled in the manner of fmt.Printf.
// It calls os.Exit() with exit code 1. // It calls os.Exit() with exit code 1.
func Fatalf(format string, args ...any) { func Fatalf(format string, args ...any) {
grpclog.Logger.Fatalf(format, args...) internal.LoggerV2Impl.Fatalf(format, args...)
// Make sure fatal logs will exit. // Make sure fatal logs will exit.
os.Exit(1) os.Exit(1)
} }
// Fatalln logs to the FATAL log. Arguments are handled in the manner of fmt.Println. // Fatalln logs to the FATAL log. Arguments are handled in the manner of fmt.Println.
// It calle os.Exit()) with exit code 1. // It calls os.Exit() with exit code 1.
func Fatalln(args ...any) { func Fatalln(args ...any) {
grpclog.Logger.Fatalln(args...) internal.LoggerV2Impl.Fatalln(args...)
// Make sure fatal logs will exit. // Make sure fatal logs will exit.
os.Exit(1) os.Exit(1)
} }
@ -114,19 +111,76 @@ func Fatalln(args ...any) {
// //
// Deprecated: use Info. // Deprecated: use Info.
func Print(args ...any) { func Print(args ...any) {
grpclog.Logger.Info(args...) internal.LoggerV2Impl.Info(args...)
} }
// Printf prints to the logger. Arguments are handled in the manner of fmt.Printf. // Printf prints to the logger. Arguments are handled in the manner of fmt.Printf.
// //
// Deprecated: use Infof. // Deprecated: use Infof.
func Printf(format string, args ...any) { func Printf(format string, args ...any) {
grpclog.Logger.Infof(format, args...) internal.LoggerV2Impl.Infof(format, args...)
} }
// Println prints to the logger. Arguments are handled in the manner of fmt.Println. // Println prints to the logger. Arguments are handled in the manner of fmt.Println.
// //
// Deprecated: use Infoln. // Deprecated: use Infoln.
func Println(args ...any) { func Println(args ...any) {
grpclog.Logger.Infoln(args...) internal.LoggerV2Impl.Infoln(args...)
}
// InfoDepth logs to the INFO log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func InfoDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.InfoDepth(depth, args...)
} else {
internal.LoggerV2Impl.Infoln(args...)
}
}
// WarningDepth logs to the WARNING log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WarningDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.WarningDepth(depth, args...)
} else {
internal.LoggerV2Impl.Warningln(args...)
}
}
// ErrorDepth logs to the ERROR log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ErrorDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.ErrorDepth(depth, args...)
} else {
internal.LoggerV2Impl.Errorln(args...)
}
}
// FatalDepth logs to the FATAL log at the specified depth.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func FatalDepth(depth int, args ...any) {
if internal.DepthLoggerV2Impl != nil {
internal.DepthLoggerV2Impl.FatalDepth(depth, args...)
} else {
internal.LoggerV2Impl.Fatalln(args...)
}
os.Exit(1)
} }

View File

@ -1,9 +1,6 @@
//go:build !linux
// +build !linux
/* /*
* *
* Copyright 2018 gRPC authors. * Copyright 2024 gRPC authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,9 +16,11 @@
* *
*/ */
package channelz // Package internal contains functionality internal to the grpclog package.
package internal
// GetSocketOption gets the socket option info of the conn. // LoggerV2Impl is the logger used for the non-depth log functions.
func GetSocketOption(c any) *SocketOptionData { var LoggerV2Impl LoggerV2
return nil
} // DepthLoggerV2Impl is the logger used for the depth log functions.
var DepthLoggerV2Impl DepthLoggerV2

View File

@ -0,0 +1,87 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package internal
// Logger mimics golang's standard Logger as an interface.
//
// Deprecated: use LoggerV2.
type Logger interface {
Fatal(args ...any)
Fatalf(format string, args ...any)
Fatalln(args ...any)
Print(args ...any)
Printf(format string, args ...any)
Println(args ...any)
}
// LoggerWrapper wraps Logger into a LoggerV2.
type LoggerWrapper struct {
Logger
}
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Info(args ...any) {
l.Logger.Print(args...)
}
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Infoln(args ...any) {
l.Logger.Println(args...)
}
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Infof(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Warning(args ...any) {
l.Logger.Print(args...)
}
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Warningln(args ...any) {
l.Logger.Println(args...)
}
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Warningf(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
func (l *LoggerWrapper) Error(args ...any) {
l.Logger.Print(args...)
}
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
func (l *LoggerWrapper) Errorln(args ...any) {
l.Logger.Println(args...)
}
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
func (l *LoggerWrapper) Errorf(format string, args ...any) {
l.Logger.Printf(format, args...)
}
// V reports whether verbosity level l is at least the requested verbose level.
func (*LoggerWrapper) V(l int) bool {
// Returns true for all verbose level.
return true
}

View File

@ -1,6 +1,6 @@
/* /*
* *
* Copyright 2020 gRPC authors. * Copyright 2024 gRPC authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,59 +16,17 @@
* *
*/ */
// Package grpclog (internal) defines depth logging for grpc. package internal
package grpclog
import ( import (
"encoding/json"
"fmt"
"io"
"log"
"os" "os"
) )
// Logger is the logger used for the non-depth log functions.
var Logger LoggerV2
// DepthLogger is the logger used for the depth log functions.
var DepthLogger DepthLoggerV2
// InfoDepth logs to the INFO log at the specified depth.
func InfoDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.InfoDepth(depth, args...)
} else {
Logger.Infoln(args...)
}
}
// WarningDepth logs to the WARNING log at the specified depth.
func WarningDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.WarningDepth(depth, args...)
} else {
Logger.Warningln(args...)
}
}
// ErrorDepth logs to the ERROR log at the specified depth.
func ErrorDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.ErrorDepth(depth, args...)
} else {
Logger.Errorln(args...)
}
}
// FatalDepth logs to the FATAL log at the specified depth.
func FatalDepth(depth int, args ...any) {
if DepthLogger != nil {
DepthLogger.FatalDepth(depth, args...)
} else {
Logger.Fatalln(args...)
}
os.Exit(1)
}
// LoggerV2 does underlying logging work for grpclog. // LoggerV2 does underlying logging work for grpclog.
// This is a copy of the LoggerV2 defined in the external grpclog package. It
// is defined here to avoid a circular dependency.
type LoggerV2 interface { type LoggerV2 interface {
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print. // Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any) Info(args ...any)
@ -107,14 +65,13 @@ type LoggerV2 interface {
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements // DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the below functions will be called with the appropriate stack // DepthLoggerV2, the below functions will be called with the appropriate stack
// depth set for trivial functions the logger may ignore. // depth set for trivial functions the logger may ignore.
// This is a copy of the DepthLoggerV2 defined in the external grpclog package.
// It is defined here to avoid a circular dependency.
// //
// # Experimental // # Experimental
// //
// Notice: This type is EXPERIMENTAL and may be changed or removed in a // Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
type DepthLoggerV2 interface { type DepthLoggerV2 interface {
LoggerV2
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println. // InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any) InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println. // WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
@ -124,3 +81,124 @@ type DepthLoggerV2 interface {
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println. // FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any) FatalDepth(depth int, args ...any)
} }
const (
// infoLog indicates Info severity.
infoLog int = iota
// warningLog indicates Warning severity.
warningLog
// errorLog indicates Error severity.
errorLog
// fatalLog indicates Fatal severity.
fatalLog
)
// severityName contains the string representation of each severity.
var severityName = []string{
infoLog: "INFO",
warningLog: "WARNING",
errorLog: "ERROR",
fatalLog: "FATAL",
}
// loggerT is the default logger used by grpclog.
type loggerT struct {
m []*log.Logger
v int
jsonFormat bool
}
func (g *loggerT) output(severity int, s string) {
sevStr := severityName[severity]
if !g.jsonFormat {
g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s))
return
}
// TODO: we can also include the logging component, but that needs more
// (API) changes.
b, _ := json.Marshal(map[string]string{
"severity": sevStr,
"message": s,
})
g.m[severity].Output(2, string(b))
}
func (g *loggerT) Info(args ...any) {
g.output(infoLog, fmt.Sprint(args...))
}
func (g *loggerT) Infoln(args ...any) {
g.output(infoLog, fmt.Sprintln(args...))
}
func (g *loggerT) Infof(format string, args ...any) {
g.output(infoLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Warning(args ...any) {
g.output(warningLog, fmt.Sprint(args...))
}
func (g *loggerT) Warningln(args ...any) {
g.output(warningLog, fmt.Sprintln(args...))
}
func (g *loggerT) Warningf(format string, args ...any) {
g.output(warningLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Error(args ...any) {
g.output(errorLog, fmt.Sprint(args...))
}
func (g *loggerT) Errorln(args ...any) {
g.output(errorLog, fmt.Sprintln(args...))
}
func (g *loggerT) Errorf(format string, args ...any) {
g.output(errorLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Fatal(args ...any) {
g.output(fatalLog, fmt.Sprint(args...))
os.Exit(1)
}
func (g *loggerT) Fatalln(args ...any) {
g.output(fatalLog, fmt.Sprintln(args...))
os.Exit(1)
}
func (g *loggerT) Fatalf(format string, args ...any) {
g.output(fatalLog, fmt.Sprintf(format, args...))
os.Exit(1)
}
func (g *loggerT) V(l int) bool {
return l <= g.v
}
// LoggerV2Config configures the LoggerV2 implementation.
type LoggerV2Config struct {
// Verbosity sets the verbosity level of the logger.
Verbosity int
// FormatJSON controls whether the logger should output logs in JSON format.
FormatJSON bool
}
// NewLoggerV2 creates a new LoggerV2 instance with the provided configuration.
// The infoW, warningW, and errorW writers are used to write log messages of
// different severity levels.
func NewLoggerV2(infoW, warningW, errorW io.Writer, c LoggerV2Config) LoggerV2 {
var m []*log.Logger
flag := log.LstdFlags
if c.FormatJSON {
flag = 0
}
m = append(m, log.New(infoW, "", flag))
m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag))
ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal.
m = append(m, log.New(ew, "", flag))
m = append(m, log.New(ew, "", flag))
return &loggerT{m: m, v: c.Verbosity, jsonFormat: c.FormatJSON}
}

View File

@ -18,70 +18,17 @@
package grpclog package grpclog
import "google.golang.org/grpc/internal/grpclog" import "google.golang.org/grpc/grpclog/internal"
// Logger mimics golang's standard Logger as an interface. // Logger mimics golang's standard Logger as an interface.
// //
// Deprecated: use LoggerV2. // Deprecated: use LoggerV2.
type Logger interface { type Logger internal.Logger
Fatal(args ...any)
Fatalf(format string, args ...any)
Fatalln(args ...any)
Print(args ...any)
Printf(format string, args ...any)
Println(args ...any)
}
// SetLogger sets the logger that is used in grpc. Call only from // SetLogger sets the logger that is used in grpc. Call only from
// init() functions. // init() functions.
// //
// Deprecated: use SetLoggerV2. // Deprecated: use SetLoggerV2.
func SetLogger(l Logger) { func SetLogger(l Logger) {
grpclog.Logger = &loggerWrapper{Logger: l} internal.LoggerV2Impl = &internal.LoggerWrapper{Logger: l}
}
// loggerWrapper wraps Logger into a LoggerV2.
type loggerWrapper struct {
Logger
}
func (g *loggerWrapper) Info(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Infoln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Infof(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) Warning(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Warningln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Warningf(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) Error(args ...any) {
g.Logger.Print(args...)
}
func (g *loggerWrapper) Errorln(args ...any) {
g.Logger.Println(args...)
}
func (g *loggerWrapper) Errorf(format string, args ...any) {
g.Logger.Printf(format, args...)
}
func (g *loggerWrapper) V(l int) bool {
// Returns true for all verbose level.
return true
} }

View File

@ -19,52 +19,16 @@
package grpclog package grpclog
import ( import (
"encoding/json"
"fmt"
"io" "io"
"log"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/grpclog/internal"
) )
// LoggerV2 does underlying logging work for grpclog. // LoggerV2 does underlying logging work for grpclog.
type LoggerV2 interface { type LoggerV2 internal.LoggerV2
// Info logs to INFO log. Arguments are handled in the manner of fmt.Print.
Info(args ...any)
// Infoln logs to INFO log. Arguments are handled in the manner of fmt.Println.
Infoln(args ...any)
// Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf.
Infof(format string, args ...any)
// Warning logs to WARNING log. Arguments are handled in the manner of fmt.Print.
Warning(args ...any)
// Warningln logs to WARNING log. Arguments are handled in the manner of fmt.Println.
Warningln(args ...any)
// Warningf logs to WARNING log. Arguments are handled in the manner of fmt.Printf.
Warningf(format string, args ...any)
// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print.
Error(args ...any)
// Errorln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
Errorln(args ...any)
// Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
Errorf(format string, args ...any)
// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatal(args ...any)
// Fatalln logs to ERROR log. Arguments are handled in the manner of fmt.Println.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalln(args ...any)
// Fatalf logs to ERROR log. Arguments are handled in the manner of fmt.Printf.
// gRPC ensures that all Fatal logs will exit with os.Exit(1).
// Implementations may also call os.Exit() with a non-zero exit code.
Fatalf(format string, args ...any)
// V reports whether verbosity level l is at least the requested verbose level.
V(l int) bool
}
// SetLoggerV2 sets logger that is used in grpc to a V2 logger. // SetLoggerV2 sets logger that is used in grpc to a V2 logger.
// Not mutex-protected, should be called before any gRPC functions. // Not mutex-protected, should be called before any gRPC functions.
@ -72,34 +36,8 @@ func SetLoggerV2(l LoggerV2) {
if _, ok := l.(*componentData); ok { if _, ok := l.(*componentData); ok {
panic("cannot use component logger as grpclog logger") panic("cannot use component logger as grpclog logger")
} }
grpclog.Logger = l internal.LoggerV2Impl = l
grpclog.DepthLogger, _ = l.(grpclog.DepthLoggerV2) internal.DepthLoggerV2Impl, _ = l.(internal.DepthLoggerV2)
}
const (
// infoLog indicates Info severity.
infoLog int = iota
// warningLog indicates Warning severity.
warningLog
// errorLog indicates Error severity.
errorLog
// fatalLog indicates Fatal severity.
fatalLog
)
// severityName contains the string representation of each severity.
var severityName = []string{
infoLog: "INFO",
warningLog: "WARNING",
errorLog: "ERROR",
fatalLog: "FATAL",
}
// loggerT is the default logger used by grpclog.
type loggerT struct {
m []*log.Logger
v int
jsonFormat bool
} }
// NewLoggerV2 creates a loggerV2 with the provided writers. // NewLoggerV2 creates a loggerV2 with the provided writers.
@ -108,32 +46,13 @@ type loggerT struct {
// Warning logs will be written to warningW and infoW. // Warning logs will be written to warningW and infoW.
// Info logs will be written to infoW. // Info logs will be written to infoW.
func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 { func NewLoggerV2(infoW, warningW, errorW io.Writer) LoggerV2 {
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{}) return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{})
} }
// NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and // NewLoggerV2WithVerbosity creates a loggerV2 with the provided writers and
// verbosity level. // verbosity level.
func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 { func NewLoggerV2WithVerbosity(infoW, warningW, errorW io.Writer, v int) LoggerV2 {
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{verbose: v}) return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{Verbosity: v})
}
type loggerV2Config struct {
verbose int
jsonFormat bool
}
func newLoggerV2WithConfig(infoW, warningW, errorW io.Writer, c loggerV2Config) LoggerV2 {
var m []*log.Logger
flag := log.LstdFlags
if c.jsonFormat {
flag = 0
}
m = append(m, log.New(infoW, "", flag))
m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag))
ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal.
m = append(m, log.New(ew, "", flag))
m = append(m, log.New(ew, "", flag))
return &loggerT{m: m, v: c.verbose, jsonFormat: c.jsonFormat}
} }
// newLoggerV2 creates a loggerV2 to be used as default logger. // newLoggerV2 creates a loggerV2 to be used as default logger.
@ -161,82 +80,12 @@ func newLoggerV2() LoggerV2 {
jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json") jsonFormat := strings.EqualFold(os.Getenv("GRPC_GO_LOG_FORMATTER"), "json")
return newLoggerV2WithConfig(infoW, warningW, errorW, loggerV2Config{ return internal.NewLoggerV2(infoW, warningW, errorW, internal.LoggerV2Config{
verbose: v, Verbosity: v,
jsonFormat: jsonFormat, FormatJSON: jsonFormat,
}) })
} }
func (g *loggerT) output(severity int, s string) {
sevStr := severityName[severity]
if !g.jsonFormat {
g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s))
return
}
// TODO: we can also include the logging component, but that needs more
// (API) changes.
b, _ := json.Marshal(map[string]string{
"severity": sevStr,
"message": s,
})
g.m[severity].Output(2, string(b))
}
func (g *loggerT) Info(args ...any) {
g.output(infoLog, fmt.Sprint(args...))
}
func (g *loggerT) Infoln(args ...any) {
g.output(infoLog, fmt.Sprintln(args...))
}
func (g *loggerT) Infof(format string, args ...any) {
g.output(infoLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Warning(args ...any) {
g.output(warningLog, fmt.Sprint(args...))
}
func (g *loggerT) Warningln(args ...any) {
g.output(warningLog, fmt.Sprintln(args...))
}
func (g *loggerT) Warningf(format string, args ...any) {
g.output(warningLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Error(args ...any) {
g.output(errorLog, fmt.Sprint(args...))
}
func (g *loggerT) Errorln(args ...any) {
g.output(errorLog, fmt.Sprintln(args...))
}
func (g *loggerT) Errorf(format string, args ...any) {
g.output(errorLog, fmt.Sprintf(format, args...))
}
func (g *loggerT) Fatal(args ...any) {
g.output(fatalLog, fmt.Sprint(args...))
os.Exit(1)
}
func (g *loggerT) Fatalln(args ...any) {
g.output(fatalLog, fmt.Sprintln(args...))
os.Exit(1)
}
func (g *loggerT) Fatalf(format string, args ...any) {
g.output(fatalLog, fmt.Sprintf(format, args...))
os.Exit(1)
}
func (g *loggerT) V(l int) bool {
return l <= g.v
}
// DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements // DepthLoggerV2 logs at a specified call frame. If a LoggerV2 also implements
// DepthLoggerV2, the below functions will be called with the appropriate stack // DepthLoggerV2, the below functions will be called with the appropriate stack
// depth set for trivial functions the logger may ignore. // depth set for trivial functions the logger may ignore.
@ -245,14 +94,4 @@ func (g *loggerT) V(l int) bool {
// //
// Notice: This type is EXPERIMENTAL and may be changed or removed in a // Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
type DepthLoggerV2 interface { type DepthLoggerV2 internal.DepthLoggerV2
LoggerV2
// InfoDepth logs to INFO log at the specified depth. Arguments are handled in the manner of fmt.Println.
InfoDepth(depth int, args ...any)
// WarningDepth logs to WARNING log at the specified depth. Arguments are handled in the manner of fmt.Println.
WarningDepth(depth int, args ...any)
// ErrorDepth logs to ERROR log at the specified depth. Arguments are handled in the manner of fmt.Println.
ErrorDepth(depth int, args ...any)
// FatalDepth logs to FATAL log at the specified depth. Arguments are handled in the manner of fmt.Println.
FatalDepth(depth int, args ...any)
}

View File

@ -17,8 +17,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.32.0 // protoc-gen-go v1.34.1
// protoc v4.25.2 // protoc v5.27.1
// source: grpc/health/v1/health.proto // source: grpc/health/v1/health.proto
package grpc_health_v1 package grpc_health_v1

View File

@ -17,8 +17,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.3.0 // - protoc-gen-go-grpc v1.5.1
// - protoc v4.25.2 // - protoc v5.27.1
// source: grpc/health/v1/health.proto // source: grpc/health/v1/health.proto
package grpc_health_v1 package grpc_health_v1
@ -32,8 +32,8 @@ import (
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against. // is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later. // Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion7 const _ = grpc.SupportPackageIsVersion9
const ( const (
Health_Check_FullMethodName = "/grpc.health.v1.Health/Check" Health_Check_FullMethodName = "/grpc.health.v1.Health/Check"
@ -43,6 +43,10 @@ const (
// HealthClient is the client API for Health service. // HealthClient is the client API for Health service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md.
type HealthClient interface { type HealthClient interface {
// Check gets the health of the specified service. If the requested service // Check gets the health of the specified service. If the requested service
// is unknown, the call will fail with status NOT_FOUND. If the caller does // is unknown, the call will fail with status NOT_FOUND. If the caller does
@ -69,7 +73,7 @@ type HealthClient interface {
// should assume this method is not supported and should not retry the // should assume this method is not supported and should not retry the
// call. If the call terminates with any other status (including OK), // call. If the call terminates with any other status (including OK),
// clients should retry the call with appropriate exponential backoff. // clients should retry the call with appropriate exponential backoff.
Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HealthCheckResponse], error)
} }
type healthClient struct { type healthClient struct {
@ -81,20 +85,22 @@ func NewHealthClient(cc grpc.ClientConnInterface) HealthClient {
} }
func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) { func (c *healthClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(HealthCheckResponse) out := new(HealthCheckResponse)
err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, opts...) err := c.cc.Invoke(ctx, Health_Check_FullMethodName, in, out, cOpts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return out, nil return out, nil
} }
func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (Health_WatchClient, error) { func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[HealthCheckResponse], error) {
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, opts...) cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &Health_ServiceDesc.Streams[0], Health_Watch_FullMethodName, cOpts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
x := &healthWatchClient{stream} x := &grpc.GenericClientStream[HealthCheckRequest, HealthCheckResponse]{ClientStream: stream}
if err := x.ClientStream.SendMsg(in); err != nil { if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err return nil, err
} }
@ -104,26 +110,16 @@ func (c *healthClient) Watch(ctx context.Context, in *HealthCheckRequest, opts .
return x, nil return x, nil
} }
type Health_WatchClient interface { // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
Recv() (*HealthCheckResponse, error) type Health_WatchClient = grpc.ServerStreamingClient[HealthCheckResponse]
grpc.ClientStream
}
type healthWatchClient struct {
grpc.ClientStream
}
func (x *healthWatchClient) Recv() (*HealthCheckResponse, error) {
m := new(HealthCheckResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// HealthServer is the server API for Health service. // HealthServer is the server API for Health service.
// All implementations should embed UnimplementedHealthServer // All implementations should embed UnimplementedHealthServer
// for forward compatibility // for forward compatibility.
//
// Health is gRPC's mechanism for checking whether a server is able to handle
// RPCs. Its semantics are documented in
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md.
type HealthServer interface { type HealthServer interface {
// Check gets the health of the specified service. If the requested service // Check gets the health of the specified service. If the requested service
// is unknown, the call will fail with status NOT_FOUND. If the caller does // is unknown, the call will fail with status NOT_FOUND. If the caller does
@ -150,19 +146,23 @@ type HealthServer interface {
// should assume this method is not supported and should not retry the // should assume this method is not supported and should not retry the
// call. If the call terminates with any other status (including OK), // call. If the call terminates with any other status (including OK),
// clients should retry the call with appropriate exponential backoff. // clients should retry the call with appropriate exponential backoff.
Watch(*HealthCheckRequest, Health_WatchServer) error Watch(*HealthCheckRequest, grpc.ServerStreamingServer[HealthCheckResponse]) error
} }
// UnimplementedHealthServer should be embedded to have forward compatible implementations. // UnimplementedHealthServer should be embedded to have
type UnimplementedHealthServer struct { // forward compatible implementations.
} //
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedHealthServer struct{}
func (UnimplementedHealthServer) Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) { func (UnimplementedHealthServer) Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Check not implemented") return nil, status.Errorf(codes.Unimplemented, "method Check not implemented")
} }
func (UnimplementedHealthServer) Watch(*HealthCheckRequest, Health_WatchServer) error { func (UnimplementedHealthServer) Watch(*HealthCheckRequest, grpc.ServerStreamingServer[HealthCheckResponse]) error {
return status.Errorf(codes.Unimplemented, "method Watch not implemented") return status.Errorf(codes.Unimplemented, "method Watch not implemented")
} }
func (UnimplementedHealthServer) testEmbeddedByValue() {}
// UnsafeHealthServer may be embedded to opt out of forward compatibility for this service. // UnsafeHealthServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to HealthServer will // Use of this interface is not recommended, as added methods to HealthServer will
@ -172,6 +172,13 @@ type UnsafeHealthServer interface {
} }
func RegisterHealthServer(s grpc.ServiceRegistrar, srv HealthServer) { func RegisterHealthServer(s grpc.ServiceRegistrar, srv HealthServer) {
// If the following call panics, it indicates UnimplementedHealthServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&Health_ServiceDesc, srv) s.RegisterService(&Health_ServiceDesc, srv)
} }
@ -198,21 +205,11 @@ func _Health_Watch_Handler(srv interface{}, stream grpc.ServerStream) error {
if err := stream.RecvMsg(m); err != nil { if err := stream.RecvMsg(m); err != nil {
return err return err
} }
return srv.(HealthServer).Watch(m, &healthWatchServer{stream}) return srv.(HealthServer).Watch(m, &grpc.GenericServerStream[HealthCheckRequest, HealthCheckResponse]{ServerStream: stream})
} }
type Health_WatchServer interface { // This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
Send(*HealthCheckResponse) error type Health_WatchServer = grpc.ServerStreamingServer[HealthCheckResponse]
grpc.ServerStream
}
type healthWatchServer struct {
grpc.ServerStream
}
func (x *healthWatchServer) Send(m *HealthCheckResponse) error {
return x.ServerStream.SendMsg(m)
}
// Health_ServiceDesc is the grpc.ServiceDesc for Health service. // Health_ServiceDesc is the grpc.ServiceDesc for Health service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,

View File

@ -25,10 +25,10 @@ package backoff
import ( import (
"context" "context"
"errors" "errors"
"math/rand"
"time" "time"
grpcbackoff "google.golang.org/grpc/backoff" grpcbackoff "google.golang.org/grpc/backoff"
"google.golang.org/grpc/internal/grpcrand"
) )
// Strategy defines the methodology for backing off after a grpc connection // Strategy defines the methodology for backing off after a grpc connection
@ -67,7 +67,7 @@ func (bc Exponential) Backoff(retries int) time.Duration {
} }
// Randomize backoff delays so that if a cluster of requests start at // Randomize backoff delays so that if a cluster of requests start at
// the same time, they won't operate in lockstep. // the same time, they won't operate in lockstep.
backoff *= 1 + bc.Config.Jitter*(grpcrand.Float64()*2-1) backoff *= 1 + bc.Config.Jitter*(rand.Float64()*2-1)
if backoff < 0 { if backoff < 0 {
return 0 return 0
} }

View File

@ -0,0 +1,82 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package gracefulswitch
import (
"encoding/json"
"fmt"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/serviceconfig"
)
type lbConfig struct {
serviceconfig.LoadBalancingConfig
childBuilder balancer.Builder
childConfig serviceconfig.LoadBalancingConfig
}
func ChildName(l serviceconfig.LoadBalancingConfig) string {
return l.(*lbConfig).childBuilder.Name()
}
// ParseConfig parses a child config list and returns a LB config for the
// gracefulswitch Balancer.
//
// cfg is expected to be a json.RawMessage containing a JSON array of LB policy
// names + configs as the format of the "loadBalancingConfig" field in
// ServiceConfig. It returns a type that should be passed to
// UpdateClientConnState in the BalancerConfig field.
func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var lbCfg []map[string]json.RawMessage
if err := json.Unmarshal(cfg, &lbCfg); err != nil {
return nil, err
}
for i, e := range lbCfg {
if len(e) != 1 {
return nil, fmt.Errorf("expected a JSON struct with one entry; received entry %v at index %d", e, i)
}
var name string
var jsonCfg json.RawMessage
for name, jsonCfg = range e {
}
builder := balancer.Get(name)
if builder == nil {
// Skip unregistered balancer names.
continue
}
parser, ok := builder.(balancer.ConfigParser)
if !ok {
// This is a valid child with no config.
return &lbConfig{childBuilder: builder}, nil
}
cfg, err := parser.ParseConfig(jsonCfg)
if err != nil {
return nil, fmt.Errorf("error parsing config for policy %q: %v", name, err)
}
return &lbConfig{childBuilder: builder, childConfig: cfg}, nil
}
return nil, fmt.Errorf("no supported policies found in config: %v", string(cfg))
}

View File

@ -94,13 +94,22 @@ func (gsb *Balancer) balancerCurrentOrPending(bw *balancerWrapper) bool {
// process is not complete when this method returns. This method must be called // process is not complete when this method returns. This method must be called
// synchronously alongside the rest of the balancer.Balancer methods this // synchronously alongside the rest of the balancer.Balancer methods this
// Graceful Switch Balancer implements. // Graceful Switch Balancer implements.
//
// Deprecated: use ParseConfig and pass a parsed config to UpdateClientConnState
// to cause the Balancer to automatically change to the new child when necessary.
func (gsb *Balancer) SwitchTo(builder balancer.Builder) error { func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
_, err := gsb.switchTo(builder)
return err
}
func (gsb *Balancer) switchTo(builder balancer.Builder) (*balancerWrapper, error) {
gsb.mu.Lock() gsb.mu.Lock()
if gsb.closed { if gsb.closed {
gsb.mu.Unlock() gsb.mu.Unlock()
return errBalancerClosed return nil, errBalancerClosed
} }
bw := &balancerWrapper{ bw := &balancerWrapper{
builder: builder,
gsb: gsb, gsb: gsb,
lastState: balancer.State{ lastState: balancer.State{
ConnectivityState: connectivity.Connecting, ConnectivityState: connectivity.Connecting,
@ -129,7 +138,7 @@ func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
gsb.balancerCurrent = nil gsb.balancerCurrent = nil
} }
gsb.mu.Unlock() gsb.mu.Unlock()
return balancer.ErrBadResolverState return nil, balancer.ErrBadResolverState
} }
// This write doesn't need to take gsb.mu because this field never gets read // This write doesn't need to take gsb.mu because this field never gets read
@ -138,7 +147,7 @@ func (gsb *Balancer) SwitchTo(builder balancer.Builder) error {
// bw.Balancer field will never be forwarded to until this SwitchTo() // bw.Balancer field will never be forwarded to until this SwitchTo()
// function returns. // function returns.
bw.Balancer = newBalancer bw.Balancer = newBalancer
return nil return bw, nil
} }
// Returns nil if the graceful switch balancer is closed. // Returns nil if the graceful switch balancer is closed.
@ -152,12 +161,32 @@ func (gsb *Balancer) latestBalancer() *balancerWrapper {
} }
// UpdateClientConnState forwards the update to the latest balancer created. // UpdateClientConnState forwards the update to the latest balancer created.
//
// If the state's BalancerConfig is the config returned by a call to
// gracefulswitch.ParseConfig, then this function will automatically SwitchTo
// the balancer indicated by the config before forwarding its config to it, if
// necessary.
func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error { func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error {
// The resolver data is only relevant to the most recent LB Policy. // The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer() balToUpdate := gsb.latestBalancer()
gsbCfg, ok := state.BalancerConfig.(*lbConfig)
if ok {
// Switch to the child in the config unless it is already active.
if balToUpdate == nil || gsbCfg.childBuilder.Name() != balToUpdate.builder.Name() {
var err error
balToUpdate, err = gsb.switchTo(gsbCfg.childBuilder)
if err != nil {
return fmt.Errorf("could not switch to new child balancer: %w", err)
}
}
// Unwrap the child balancer's config.
state.BalancerConfig = gsbCfg.childConfig
}
if balToUpdate == nil { if balToUpdate == nil {
return errBalancerClosed return errBalancerClosed
} }
// Perform this call without gsb.mu to prevent deadlocks if the child calls // Perform this call without gsb.mu to prevent deadlocks if the child calls
// back into the channel. The latest balancer can never be closed during a // back into the channel. The latest balancer can never be closed during a
// call from the channel, even without gsb.mu held. // call from the channel, even without gsb.mu held.
@ -169,6 +198,10 @@ func (gsb *Balancer) ResolverError(err error) {
// The resolver data is only relevant to the most recent LB Policy. // The resolver data is only relevant to the most recent LB Policy.
balToUpdate := gsb.latestBalancer() balToUpdate := gsb.latestBalancer()
if balToUpdate == nil { if balToUpdate == nil {
gsb.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: base.NewErrPicker(err),
})
return return
} }
// Perform this call without gsb.mu to prevent deadlocks if the child calls // Perform this call without gsb.mu to prevent deadlocks if the child calls
@ -262,6 +295,7 @@ func (gsb *Balancer) Close() {
type balancerWrapper struct { type balancerWrapper struct {
balancer.Balancer balancer.Balancer
gsb *Balancer gsb *Balancer
builder balancer.Builder
lastState balancer.State lastState balancer.State
subconns map[balancer.SubConn]bool // subconns created by this balancer subconns map[balancer.SubConn]bool // subconns created by this balancer

View File

@ -65,7 +65,7 @@ type TruncatingMethodLogger struct {
callID uint64 callID uint64
idWithinCallGen *callIDGenerator idWithinCallGen *callIDGenerator
sink Sink // TODO(blog): make this plugable. sink Sink // TODO(blog): make this pluggable.
} }
// NewTruncatingMethodLogger returns a new truncating method logger. // NewTruncatingMethodLogger returns a new truncating method logger.
@ -80,7 +80,7 @@ func NewTruncatingMethodLogger(h, m uint64) *TruncatingMethodLogger {
callID: idGen.next(), callID: idGen.next(),
idWithinCallGen: &callIDGenerator{}, idWithinCallGen: &callIDGenerator{},
sink: DefaultSink, // TODO(blog): make it plugable. sink: DefaultSink, // TODO(blog): make it pluggable.
} }
} }
@ -397,7 +397,7 @@ func metadataKeyOmit(key string) bool {
switch key { switch key {
case "lb-token", ":path", ":authority", "content-encoding", "content-type", "user-agent", "te": case "lb-token", ":path", ":authority", "content-encoding", "content-type", "user-agent", "te":
return true return true
case "grpc-trace-bin": // grpc-trace-bin is special because it's visiable to users. case "grpc-trace-bin": // grpc-trace-bin is special because it's visible to users.
return false return false
} }
return strings.HasPrefix(key, "grpc-") return strings.HasPrefix(key, "grpc-")

View File

@ -0,0 +1,255 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync/atomic"
"google.golang.org/grpc/connectivity"
)
// Channel represents a channel within channelz, which includes metrics and
// internal channelz data, such as channelz id, child list, etc.
type Channel struct {
Entity
// ID is the channelz id of this channel.
ID int64
// RefName is the human readable reference string of this channel.
RefName string
closeCalled bool
nestedChans map[int64]string
subChans map[int64]string
Parent *Channel
trace *ChannelTrace
// traceRefCount is the number of trace events that reference this channel.
// Non-zero traceRefCount means the trace of this channel cannot be deleted.
traceRefCount int32
ChannelMetrics ChannelMetrics
}
// Implemented to make Channel implement the Identifier interface used for
// nesting.
func (c *Channel) channelzIdentifier() {}
func (c *Channel) String() string {
if c.Parent == nil {
return fmt.Sprintf("Channel #%d", c.ID)
}
return fmt.Sprintf("%s Channel #%d", c.Parent, c.ID)
}
func (c *Channel) id() int64 {
return c.ID
}
func (c *Channel) SubChans() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(c.subChans)
}
func (c *Channel) NestedChans() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(c.nestedChans)
}
func (c *Channel) Trace() *ChannelTrace {
db.mu.RLock()
defer db.mu.RUnlock()
return c.trace.copy()
}
type ChannelMetrics struct {
// The current connectivity state of the channel.
State atomic.Pointer[connectivity.State]
// The target this channel originally tried to connect to. May be absent
Target atomic.Pointer[string]
// The number of calls started on the channel.
CallsStarted atomic.Int64
// The number of calls that have completed with an OK status.
CallsSucceeded atomic.Int64
// The number of calls that have a completed with a non-OK status.
CallsFailed atomic.Int64
// The last time a call was started on the channel.
LastCallStartedTimestamp atomic.Int64
}
// CopyFrom copies the metrics in o to c. For testing only.
func (c *ChannelMetrics) CopyFrom(o *ChannelMetrics) {
c.State.Store(o.State.Load())
c.Target.Store(o.Target.Load())
c.CallsStarted.Store(o.CallsStarted.Load())
c.CallsSucceeded.Store(o.CallsSucceeded.Load())
c.CallsFailed.Store(o.CallsFailed.Load())
c.LastCallStartedTimestamp.Store(o.LastCallStartedTimestamp.Load())
}
// Equal returns true iff the metrics of c are the same as the metrics of o.
// For testing only.
func (c *ChannelMetrics) Equal(o any) bool {
oc, ok := o.(*ChannelMetrics)
if !ok {
return false
}
if (c.State.Load() == nil) != (oc.State.Load() == nil) {
return false
}
if c.State.Load() != nil && *c.State.Load() != *oc.State.Load() {
return false
}
if (c.Target.Load() == nil) != (oc.Target.Load() == nil) {
return false
}
if c.Target.Load() != nil && *c.Target.Load() != *oc.Target.Load() {
return false
}
return c.CallsStarted.Load() == oc.CallsStarted.Load() &&
c.CallsFailed.Load() == oc.CallsFailed.Load() &&
c.CallsSucceeded.Load() == oc.CallsSucceeded.Load() &&
c.LastCallStartedTimestamp.Load() == oc.LastCallStartedTimestamp.Load()
}
func strFromPointer(s *string) string {
if s == nil {
return ""
}
return *s
}
func (c *ChannelMetrics) String() string {
return fmt.Sprintf("State: %v, Target: %s, CallsStarted: %v, CallsSucceeded: %v, CallsFailed: %v, LastCallStartedTimestamp: %v",
c.State.Load(), strFromPointer(c.Target.Load()), c.CallsStarted.Load(), c.CallsSucceeded.Load(), c.CallsFailed.Load(), c.LastCallStartedTimestamp.Load(),
)
}
func NewChannelMetricForTesting(state connectivity.State, target string, started, succeeded, failed, timestamp int64) *ChannelMetrics {
c := &ChannelMetrics{}
c.State.Store(&state)
c.Target.Store(&target)
c.CallsStarted.Store(started)
c.CallsSucceeded.Store(succeeded)
c.CallsFailed.Store(failed)
c.LastCallStartedTimestamp.Store(timestamp)
return c
}
func (c *Channel) addChild(id int64, e entry) {
switch v := e.(type) {
case *SubChannel:
c.subChans[id] = v.RefName
case *Channel:
c.nestedChans[id] = v.RefName
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a channel", id, e)
}
}
func (c *Channel) deleteChild(id int64) {
delete(c.subChans, id)
delete(c.nestedChans, id)
c.deleteSelfIfReady()
}
func (c *Channel) triggerDelete() {
c.closeCalled = true
c.deleteSelfIfReady()
}
func (c *Channel) getParentID() int64 {
if c.Parent == nil {
return -1
}
return c.Parent.ID
}
// deleteSelfFromTree tries to delete the channel from the channelz entry relation tree, which means
// deleting the channel reference from its parent's child list.
//
// In order for a channel to be deleted from the tree, it must meet the criteria that, removal of the
// corresponding grpc object has been invoked, and the channel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (c *Channel) deleteSelfFromTree() (deleted bool) {
if !c.closeCalled || len(c.subChans)+len(c.nestedChans) != 0 {
return false
}
// not top channel
if c.Parent != nil {
c.Parent.deleteChild(c.ID)
}
return true
}
// deleteSelfFromMap checks whether it is valid to delete the channel from the map, which means
// deleting the channel from channelz's tracking entirely. Users can no longer use id to query the
// channel, and its memory will be garbage collected.
//
// The trace reference count of the channel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (c *Channel) deleteSelfFromMap() (delete bool) {
return c.getTraceRefCount() == 0
}
// deleteSelfIfReady tries to delete the channel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the channel from the entry relation tree, i.e. delete the channel reference from its
// parent's child list.
// 2. delete the channel from the map, i.e. delete the channel entirely from channelz. Lookup by id
// will return entry not found error.
func (c *Channel) deleteSelfIfReady() {
if !c.deleteSelfFromTree() {
return
}
if !c.deleteSelfFromMap() {
return
}
db.deleteEntry(c.ID)
c.trace.clear()
}
func (c *Channel) getChannelTrace() *ChannelTrace {
return c.trace
}
func (c *Channel) incrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, 1)
}
func (c *Channel) decrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, -1)
}
func (c *Channel) getTraceRefCount() int {
i := atomic.LoadInt32(&c.traceRefCount)
return int(i)
}
func (c *Channel) getRefName() string {
return c.RefName
}

View File

@ -0,0 +1,402 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sort"
"sync"
"time"
)
// entry represents a node in the channelz database.
type entry interface {
// addChild adds a child e, whose channelz id is id to child list
addChild(id int64, e entry)
// deleteChild deletes a child with channelz id to be id from child list
deleteChild(id int64)
// triggerDelete tries to delete self from channelz database. However, if
// child list is not empty, then deletion from the database is on hold until
// the last child is deleted from database.
triggerDelete()
// deleteSelfIfReady check whether triggerDelete() has been called before,
// and whether child list is now empty. If both conditions are met, then
// delete self from database.
deleteSelfIfReady()
// getParentID returns parent ID of the entry. 0 value parent ID means no parent.
getParentID() int64
Entity
}
// channelMap is the storage data structure for channelz.
//
// Methods of channelMap can be divided into two categories with respect to
// locking.
//
// 1. Methods acquire the global lock.
// 2. Methods that can only be called when global lock is held.
//
// A second type of method need always to be called inside a first type of method.
type channelMap struct {
mu sync.RWMutex
topLevelChannels map[int64]struct{}
channels map[int64]*Channel
subChannels map[int64]*SubChannel
sockets map[int64]*Socket
servers map[int64]*Server
}
func newChannelMap() *channelMap {
return &channelMap{
topLevelChannels: make(map[int64]struct{}),
channels: make(map[int64]*Channel),
subChannels: make(map[int64]*SubChannel),
sockets: make(map[int64]*Socket),
servers: make(map[int64]*Server),
}
}
func (c *channelMap) addServer(id int64, s *Server) {
c.mu.Lock()
defer c.mu.Unlock()
s.cm = c
c.servers[id] = s
}
func (c *channelMap) addChannel(id int64, cn *Channel, isTopChannel bool, pid int64) {
c.mu.Lock()
defer c.mu.Unlock()
cn.trace.cm = c
c.channels[id] = cn
if isTopChannel {
c.topLevelChannels[id] = struct{}{}
} else if p := c.channels[pid]; p != nil {
p.addChild(id, cn)
} else {
logger.Infof("channel %d references invalid parent ID %d", id, pid)
}
}
func (c *channelMap) addSubChannel(id int64, sc *SubChannel, pid int64) {
c.mu.Lock()
defer c.mu.Unlock()
sc.trace.cm = c
c.subChannels[id] = sc
if p := c.channels[pid]; p != nil {
p.addChild(id, sc)
} else {
logger.Infof("subchannel %d references invalid parent ID %d", id, pid)
}
}
func (c *channelMap) addSocket(s *Socket) {
c.mu.Lock()
defer c.mu.Unlock()
s.cm = c
c.sockets[s.ID] = s
if s.Parent == nil {
logger.Infof("normal socket %d has no parent", s.ID)
}
s.Parent.(entry).addChild(s.ID, s)
}
// removeEntry triggers the removal of an entry, which may not indeed delete the
// entry, if it has to wait on the deletion of its children and until no other
// entity's channel trace references it. It may lead to a chain of entry
// deletion. For example, deleting the last socket of a gracefully shutting down
// server will lead to the server being also deleted.
func (c *channelMap) removeEntry(id int64) {
c.mu.Lock()
defer c.mu.Unlock()
c.findEntry(id).triggerDelete()
}
// tracedChannel represents tracing operations which are present on both
// channels and subChannels.
type tracedChannel interface {
getChannelTrace() *ChannelTrace
incrTraceRefCount()
decrTraceRefCount()
getRefName() string
}
// c.mu must be held by the caller
func (c *channelMap) decrTraceRefCount(id int64) {
e := c.findEntry(id)
if v, ok := e.(tracedChannel); ok {
v.decrTraceRefCount()
e.deleteSelfIfReady()
}
}
// c.mu must be held by the caller.
func (c *channelMap) findEntry(id int64) entry {
if v, ok := c.channels[id]; ok {
return v
}
if v, ok := c.subChannels[id]; ok {
return v
}
if v, ok := c.servers[id]; ok {
return v
}
if v, ok := c.sockets[id]; ok {
return v
}
return &dummyEntry{idNotFound: id}
}
// c.mu must be held by the caller
//
// deleteEntry deletes an entry from the channelMap. Before calling this method,
// caller must check this entry is ready to be deleted, i.e removeEntry() has
// been called on it, and no children still exist.
func (c *channelMap) deleteEntry(id int64) entry {
if v, ok := c.sockets[id]; ok {
delete(c.sockets, id)
return v
}
if v, ok := c.subChannels[id]; ok {
delete(c.subChannels, id)
return v
}
if v, ok := c.channels[id]; ok {
delete(c.channels, id)
delete(c.topLevelChannels, id)
return v
}
if v, ok := c.servers[id]; ok {
delete(c.servers, id)
return v
}
return &dummyEntry{idNotFound: id}
}
func (c *channelMap) traceEvent(id int64, desc *TraceEvent) {
c.mu.Lock()
defer c.mu.Unlock()
child := c.findEntry(id)
childTC, ok := child.(tracedChannel)
if !ok {
return
}
childTC.getChannelTrace().append(&traceEvent{Desc: desc.Desc, Severity: desc.Severity, Timestamp: time.Now()})
if desc.Parent != nil {
parent := c.findEntry(child.getParentID())
var chanType RefChannelType
switch child.(type) {
case *Channel:
chanType = RefChannel
case *SubChannel:
chanType = RefSubChannel
}
if parentTC, ok := parent.(tracedChannel); ok {
parentTC.getChannelTrace().append(&traceEvent{
Desc: desc.Parent.Desc,
Severity: desc.Parent.Severity,
Timestamp: time.Now(),
RefID: id,
RefName: childTC.getRefName(),
RefType: chanType,
})
childTC.incrTraceRefCount()
}
}
}
type int64Slice []int64
func (s int64Slice) Len() int { return len(s) }
func (s int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s int64Slice) Less(i, j int) bool { return s[i] < s[j] }
func copyMap(m map[int64]string) map[int64]string {
n := make(map[int64]string)
for k, v := range m {
n[k] = v
}
return n
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *channelMap) getTopChannels(id int64, maxResults int) ([]*Channel, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
}
c.mu.RLock()
defer c.mu.RUnlock()
l := int64(len(c.topLevelChannels))
ids := make([]int64, 0, l)
for k := range c.topLevelChannels {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
end := true
var t []*Channel
for _, v := range ids[idx:] {
if len(t) == maxResults {
end = false
break
}
if cn, ok := c.channels[v]; ok {
t = append(t, cn)
}
}
return t, end
}
func (c *channelMap) getServers(id int64, maxResults int) ([]*Server, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
}
c.mu.RLock()
defer c.mu.RUnlock()
ids := make([]int64, 0, len(c.servers))
for k := range c.servers {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
end := true
var s []*Server
for _, v := range ids[idx:] {
if len(s) == maxResults {
end = false
break
}
if svr, ok := c.servers[v]; ok {
s = append(s, svr)
}
}
return s, end
}
func (c *channelMap) getServerSockets(id int64, startID int64, maxResults int) ([]*Socket, bool) {
if maxResults <= 0 {
maxResults = EntriesPerPage
}
c.mu.RLock()
defer c.mu.RUnlock()
svr, ok := c.servers[id]
if !ok {
// server with id doesn't exist.
return nil, true
}
svrskts := svr.sockets
ids := make([]int64, 0, len(svrskts))
sks := make([]*Socket, 0, min(len(svrskts), maxResults))
for k := range svrskts {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
end := true
for _, v := range ids[idx:] {
if len(sks) == maxResults {
end = false
break
}
if ns, ok := c.sockets[v]; ok {
sks = append(sks, ns)
}
}
return sks, end
}
func (c *channelMap) getChannel(id int64) *Channel {
c.mu.RLock()
defer c.mu.RUnlock()
return c.channels[id]
}
func (c *channelMap) getSubChannel(id int64) *SubChannel {
c.mu.RLock()
defer c.mu.RUnlock()
return c.subChannels[id]
}
func (c *channelMap) getSocket(id int64) *Socket {
c.mu.RLock()
defer c.mu.RUnlock()
return c.sockets[id]
}
func (c *channelMap) getServer(id int64) *Server {
c.mu.RLock()
defer c.mu.RUnlock()
return c.servers[id]
}
type dummyEntry struct {
// dummyEntry is a fake entry to handle entry not found case.
idNotFound int64
Entity
}
func (d *dummyEntry) String() string {
return fmt.Sprintf("non-existent entity #%d", d.idNotFound)
}
func (d *dummyEntry) ID() int64 { return d.idNotFound }
func (d *dummyEntry) addChild(id int64, e entry) {
// Note: It is possible for a normal program to reach here under race
// condition. For example, there could be a race between ClientConn.Close()
// info being propagated to addrConn and http2Client. ClientConn.Close()
// cancel the context and result in http2Client to error. The error info is
// then caught by transport monitor and before addrConn.tearDown() is called
// in side ClientConn.Close(). Therefore, the addrConn will create a new
// transport. And when registering the new transport in channelz, its parent
// addrConn could have already been torn down and deleted from channelz
// tracking, and thus reach the code here.
logger.Infof("attempt to add child of type %T with id %d to a parent (id=%d) that doesn't currently exist", e, id, d.idNotFound)
}
func (d *dummyEntry) deleteChild(id int64) {
// It is possible for a normal program to reach here under race condition.
// Refer to the example described in addChild().
logger.Infof("attempt to delete child with id %d from a parent (id=%d) that doesn't currently exist", id, d.idNotFound)
}
func (d *dummyEntry) triggerDelete() {
logger.Warningf("attempt to delete an entry (id=%d) that doesn't currently exist", d.idNotFound)
}
func (*dummyEntry) deleteSelfIfReady() {
// code should not reach here. deleteSelfIfReady is always called on an existing entry.
}
func (*dummyEntry) getParentID() int64 {
return 0
}
// Entity is implemented by all channelz types.
type Entity interface {
isEntity()
fmt.Stringer
id() int64
}

View File

@ -16,47 +16,32 @@
* *
*/ */
// Package channelz defines APIs for enabling channelz service, entry // Package channelz defines internal APIs for enabling channelz service, entry
// registration/deletion, and accessing channelz data. It also defines channelz // registration/deletion, and accessing channelz data. It also defines channelz
// metric struct formats. // metric struct formats.
//
// All APIs in this package are experimental.
package channelz package channelz
import ( import (
"errors"
"sort"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
) )
const (
defaultMaxTraceEntry int32 = 30
)
var ( var (
// IDGen is the global channelz entity ID generator. It should not be used // IDGen is the global channelz entity ID generator. It should not be used
// outside this package except by tests. // outside this package except by tests.
IDGen IDGenerator IDGen IDGenerator
db dbWrapper db *channelMap = newChannelMap()
// EntryPerPage defines the number of channelz entries to be shown on a web page. // EntriesPerPage defines the number of channelz entries to be shown on a web page.
EntryPerPage = int64(50) EntriesPerPage = 50
curState int32 curState int32
maxTraceEntry = defaultMaxTraceEntry
) )
// TurnOn turns on channelz data collection. // TurnOn turns on channelz data collection.
func TurnOn() { func TurnOn() {
if !IsOn() {
db.set(newChannelMap())
IDGen.Reset()
atomic.StoreInt32(&curState, 1) atomic.StoreInt32(&curState, 1)
}
} }
func init() { func init() {
@ -70,49 +55,15 @@ func IsOn() bool {
return atomic.LoadInt32(&curState) == 1 return atomic.LoadInt32(&curState) == 1
} }
// SetMaxTraceEntry sets maximum number of trace entry per entity (i.e. channel/subchannel).
// Setting it to 0 will disable channel tracing.
func SetMaxTraceEntry(i int32) {
atomic.StoreInt32(&maxTraceEntry, i)
}
// ResetMaxTraceEntryToDefault resets the maximum number of trace entry per entity to default.
func ResetMaxTraceEntryToDefault() {
atomic.StoreInt32(&maxTraceEntry, defaultMaxTraceEntry)
}
func getMaxTraceEntry() int {
i := atomic.LoadInt32(&maxTraceEntry)
return int(i)
}
// dbWarpper wraps around a reference to internal channelz data storage, and
// provide synchronized functionality to set and get the reference.
type dbWrapper struct {
mu sync.RWMutex
DB *channelMap
}
func (d *dbWrapper) set(db *channelMap) {
d.mu.Lock()
d.DB = db
d.mu.Unlock()
}
func (d *dbWrapper) get() *channelMap {
d.mu.RLock()
defer d.mu.RUnlock()
return d.DB
}
// GetTopChannels returns a slice of top channel's ChannelMetric, along with a // GetTopChannels returns a slice of top channel's ChannelMetric, along with a
// boolean indicating whether there's more top channels to be queried for. // boolean indicating whether there's more top channels to be queried for.
// //
// The arg id specifies that only top channel with id at or above it will be included // The arg id specifies that only top channel with id at or above it will be
// in the result. The returned slice is up to a length of the arg maxResults or // included in the result. The returned slice is up to a length of the arg
// EntryPerPage if maxResults is zero, and is sorted in ascending id order. // maxResults or EntriesPerPage if maxResults is zero, and is sorted in ascending
func GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) { // id order.
return db.get().GetTopChannels(id, maxResults) func GetTopChannels(id int64, maxResults int) ([]*Channel, bool) {
return db.getTopChannels(id, maxResults)
} }
// GetServers returns a slice of server's ServerMetric, along with a // GetServers returns a slice of server's ServerMetric, along with a
@ -120,73 +71,69 @@ func GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
// //
// The arg id specifies that only server with id at or above it will be included // The arg id specifies that only server with id at or above it will be included
// in the result. The returned slice is up to a length of the arg maxResults or // in the result. The returned slice is up to a length of the arg maxResults or
// EntryPerPage if maxResults is zero, and is sorted in ascending id order. // EntriesPerPage if maxResults is zero, and is sorted in ascending id order.
func GetServers(id int64, maxResults int64) ([]*ServerMetric, bool) { func GetServers(id int64, maxResults int) ([]*Server, bool) {
return db.get().GetServers(id, maxResults) return db.getServers(id, maxResults)
} }
// GetServerSockets returns a slice of server's (identified by id) normal socket's // GetServerSockets returns a slice of server's (identified by id) normal socket's
// SocketMetric, along with a boolean indicating whether there's more sockets to // SocketMetrics, along with a boolean indicating whether there's more sockets to
// be queried for. // be queried for.
// //
// The arg startID specifies that only sockets with id at or above it will be // The arg startID specifies that only sockets with id at or above it will be
// included in the result. The returned slice is up to a length of the arg maxResults // included in the result. The returned slice is up to a length of the arg maxResults
// or EntryPerPage if maxResults is zero, and is sorted in ascending id order. // or EntriesPerPage if maxResults is zero, and is sorted in ascending id order.
func GetServerSockets(id int64, startID int64, maxResults int64) ([]*SocketMetric, bool) { func GetServerSockets(id int64, startID int64, maxResults int) ([]*Socket, bool) {
return db.get().GetServerSockets(id, startID, maxResults) return db.getServerSockets(id, startID, maxResults)
} }
// GetChannel returns the ChannelMetric for the channel (identified by id). // GetChannel returns the Channel for the channel (identified by id).
func GetChannel(id int64) *ChannelMetric { func GetChannel(id int64) *Channel {
return db.get().GetChannel(id) return db.getChannel(id)
} }
// GetSubChannel returns the SubChannelMetric for the subchannel (identified by id). // GetSubChannel returns the SubChannel for the subchannel (identified by id).
func GetSubChannel(id int64) *SubChannelMetric { func GetSubChannel(id int64) *SubChannel {
return db.get().GetSubChannel(id) return db.getSubChannel(id)
} }
// GetSocket returns the SocketInternalMetric for the socket (identified by id). // GetSocket returns the Socket for the socket (identified by id).
func GetSocket(id int64) *SocketMetric { func GetSocket(id int64) *Socket {
return db.get().GetSocket(id) return db.getSocket(id)
} }
// GetServer returns the ServerMetric for the server (identified by id). // GetServer returns the ServerMetric for the server (identified by id).
func GetServer(id int64) *ServerMetric { func GetServer(id int64) *Server {
return db.get().GetServer(id) return db.getServer(id)
} }
// RegisterChannel registers the given channel c in the channelz database with // RegisterChannel registers the given channel c in the channelz database with
// ref as its reference name, and adds it to the child list of its parent // target as its target and reference name, and adds it to the child list of its
// (identified by pid). pid == nil means no parent. // parent. parent == nil means no parent.
// //
// Returns a unique channelz identifier assigned to this channel. // Returns a unique channelz identifier assigned to this channel.
// //
// If channelz is not turned ON, the channelz database is not mutated. // If channelz is not turned ON, the channelz database is not mutated.
func RegisterChannel(c Channel, pid *Identifier, ref string) *Identifier { func RegisterChannel(parent *Channel, target string) *Channel {
id := IDGen.genID() id := IDGen.genID()
var parent int64
isTopChannel := true
if pid != nil {
isTopChannel = false
parent = pid.Int()
}
if !IsOn() { if !IsOn() {
return newIdentifer(RefChannel, id, pid) return &Channel{ID: id}
} }
cn := &channel{ isTopChannel := parent == nil
refName: ref,
c: c, cn := &Channel{
subChans: make(map[int64]string), ID: id,
RefName: target,
nestedChans: make(map[int64]string), nestedChans: make(map[int64]string),
id: id, subChans: make(map[int64]string),
pid: parent, Parent: parent,
trace: &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())}, trace: &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())},
} }
db.get().addChannel(id, cn, isTopChannel, parent) cn.ChannelMetrics.Target.Store(&target)
return newIdentifer(RefChannel, id, pid) db.addChannel(id, cn, isTopChannel, cn.getParentID())
return cn
} }
// RegisterSubChannel registers the given subChannel c in the channelz database // RegisterSubChannel registers the given subChannel c in the channelz database
@ -196,555 +143,67 @@ func RegisterChannel(c Channel, pid *Identifier, ref string) *Identifier {
// Returns a unique channelz identifier assigned to this subChannel. // Returns a unique channelz identifier assigned to this subChannel.
// //
// If channelz is not turned ON, the channelz database is not mutated. // If channelz is not turned ON, the channelz database is not mutated.
func RegisterSubChannel(c Channel, pid *Identifier, ref string) (*Identifier, error) { func RegisterSubChannel(parent *Channel, ref string) *SubChannel {
if pid == nil {
return nil, errors.New("a SubChannel's parent id cannot be nil")
}
id := IDGen.genID() id := IDGen.genID()
if !IsOn() { sc := &SubChannel{
return newIdentifer(RefSubChannel, id, pid), nil ID: id,
RefName: ref,
parent: parent,
} }
sc := &subChannel{ if !IsOn() {
refName: ref, return sc
c: c,
sockets: make(map[int64]string),
id: id,
pid: pid.Int(),
trace: &channelTrace{createdTime: time.Now(), events: make([]*TraceEvent, 0, getMaxTraceEntry())},
} }
db.get().addSubChannel(id, sc, pid.Int())
return newIdentifer(RefSubChannel, id, pid), nil sc.sockets = make(map[int64]string)
sc.trace = &ChannelTrace{CreationTime: time.Now(), Events: make([]*traceEvent, 0, getMaxTraceEntry())}
db.addSubChannel(id, sc, parent.ID)
return sc
} }
// RegisterServer registers the given server s in channelz database. It returns // RegisterServer registers the given server s in channelz database. It returns
// the unique channelz tracking id assigned to this server. // the unique channelz tracking id assigned to this server.
// //
// If channelz is not turned ON, the channelz database is not mutated. // If channelz is not turned ON, the channelz database is not mutated.
func RegisterServer(s Server, ref string) *Identifier { func RegisterServer(ref string) *Server {
id := IDGen.genID() id := IDGen.genID()
if !IsOn() { if !IsOn() {
return newIdentifer(RefServer, id, nil) return &Server{ID: id}
} }
svr := &server{ svr := &Server{
refName: ref, RefName: ref,
s: s,
sockets: make(map[int64]string), sockets: make(map[int64]string),
listenSockets: make(map[int64]string), listenSockets: make(map[int64]string),
id: id, ID: id,
} }
db.get().addServer(id, svr) db.addServer(id, svr)
return newIdentifer(RefServer, id, nil) return svr
} }
// RegisterListenSocket registers the given listen socket s in channelz database // RegisterSocket registers the given normal socket s in channelz database
// with ref as its reference name, and add it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to
// this listen socket.
//
// If channelz is not turned ON, the channelz database is not mutated.
func RegisterListenSocket(s Socket, pid *Identifier, ref string) (*Identifier, error) {
if pid == nil {
return nil, errors.New("a ListenSocket's parent id cannot be 0")
}
id := IDGen.genID()
if !IsOn() {
return newIdentifer(RefListenSocket, id, pid), nil
}
ls := &listenSocket{refName: ref, s: s, id: id, pid: pid.Int()}
db.get().addListenSocket(id, ls, pid.Int())
return newIdentifer(RefListenSocket, id, pid), nil
}
// RegisterNormalSocket registers the given normal socket s in channelz database
// with ref as its reference name, and adds it to the child list of its parent // with ref as its reference name, and adds it to the child list of its parent
// (identified by pid). It returns the unique channelz tracking id assigned to // (identified by skt.Parent, which must be set). It returns the unique channelz
// this normal socket. // tracking id assigned to this normal socket.
// //
// If channelz is not turned ON, the channelz database is not mutated. // If channelz is not turned ON, the channelz database is not mutated.
func RegisterNormalSocket(s Socket, pid *Identifier, ref string) (*Identifier, error) { func RegisterSocket(skt *Socket) *Socket {
if pid == nil { skt.ID = IDGen.genID()
return nil, errors.New("a NormalSocket's parent id cannot be 0") if IsOn() {
db.addSocket(skt)
} }
id := IDGen.genID() return skt
if !IsOn() {
return newIdentifer(RefNormalSocket, id, pid), nil
}
ns := &normalSocket{refName: ref, s: s, id: id, pid: pid.Int()}
db.get().addNormalSocket(id, ns, pid.Int())
return newIdentifer(RefNormalSocket, id, pid), nil
} }
// RemoveEntry removes an entry with unique channelz tracking id to be id from // RemoveEntry removes an entry with unique channelz tracking id to be id from
// channelz database. // channelz database.
// //
// If channelz is not turned ON, this function is a no-op. // If channelz is not turned ON, this function is a no-op.
func RemoveEntry(id *Identifier) { func RemoveEntry(id int64) {
if !IsOn() { if !IsOn() {
return return
} }
db.get().removeEntry(id.Int()) db.removeEntry(id)
}
// TraceEventDesc is what the caller of AddTraceEvent should provide to describe
// the event to be added to the channel trace.
//
// The Parent field is optional. It is used for an event that will be recorded
// in the entity's parent trace.
type TraceEventDesc struct {
Desc string
Severity Severity
Parent *TraceEventDesc
}
// AddTraceEvent adds trace related to the entity with specified id, using the
// provided TraceEventDesc.
//
// If channelz is not turned ON, this will simply log the event descriptions.
func AddTraceEvent(l grpclog.DepthLoggerV2, id *Identifier, depth int, desc *TraceEventDesc) {
// Log only the trace description associated with the bottom most entity.
switch desc.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, withParens(id)+desc.Desc)
case CtWarning:
l.WarningDepth(depth+1, withParens(id)+desc.Desc)
case CtError:
l.ErrorDepth(depth+1, withParens(id)+desc.Desc)
}
if getMaxTraceEntry() == 0 {
return
}
if IsOn() {
db.get().traceEvent(id.Int(), desc)
}
}
// channelMap is the storage data structure for channelz.
// Methods of channelMap can be divided in two two categories with respect to locking.
// 1. Methods acquire the global lock.
// 2. Methods that can only be called when global lock is held.
// A second type of method need always to be called inside a first type of method.
type channelMap struct {
mu sync.RWMutex
topLevelChannels map[int64]struct{}
servers map[int64]*server
channels map[int64]*channel
subChannels map[int64]*subChannel
listenSockets map[int64]*listenSocket
normalSockets map[int64]*normalSocket
}
func newChannelMap() *channelMap {
return &channelMap{
topLevelChannels: make(map[int64]struct{}),
channels: make(map[int64]*channel),
listenSockets: make(map[int64]*listenSocket),
normalSockets: make(map[int64]*normalSocket),
servers: make(map[int64]*server),
subChannels: make(map[int64]*subChannel),
}
}
func (c *channelMap) addServer(id int64, s *server) {
c.mu.Lock()
s.cm = c
c.servers[id] = s
c.mu.Unlock()
}
func (c *channelMap) addChannel(id int64, cn *channel, isTopChannel bool, pid int64) {
c.mu.Lock()
cn.cm = c
cn.trace.cm = c
c.channels[id] = cn
if isTopChannel {
c.topLevelChannels[id] = struct{}{}
} else {
c.findEntry(pid).addChild(id, cn)
}
c.mu.Unlock()
}
func (c *channelMap) addSubChannel(id int64, sc *subChannel, pid int64) {
c.mu.Lock()
sc.cm = c
sc.trace.cm = c
c.subChannels[id] = sc
c.findEntry(pid).addChild(id, sc)
c.mu.Unlock()
}
func (c *channelMap) addListenSocket(id int64, ls *listenSocket, pid int64) {
c.mu.Lock()
ls.cm = c
c.listenSockets[id] = ls
c.findEntry(pid).addChild(id, ls)
c.mu.Unlock()
}
func (c *channelMap) addNormalSocket(id int64, ns *normalSocket, pid int64) {
c.mu.Lock()
ns.cm = c
c.normalSockets[id] = ns
c.findEntry(pid).addChild(id, ns)
c.mu.Unlock()
}
// removeEntry triggers the removal of an entry, which may not indeed delete the entry, if it has to
// wait on the deletion of its children and until no other entity's channel trace references it.
// It may lead to a chain of entry deletion. For example, deleting the last socket of a gracefully
// shutting down server will lead to the server being also deleted.
func (c *channelMap) removeEntry(id int64) {
c.mu.Lock()
c.findEntry(id).triggerDelete()
c.mu.Unlock()
}
// c.mu must be held by the caller
func (c *channelMap) decrTraceRefCount(id int64) {
e := c.findEntry(id)
if v, ok := e.(tracedChannel); ok {
v.decrTraceRefCount()
e.deleteSelfIfReady()
}
}
// c.mu must be held by the caller.
func (c *channelMap) findEntry(id int64) entry {
var v entry
var ok bool
if v, ok = c.channels[id]; ok {
return v
}
if v, ok = c.subChannels[id]; ok {
return v
}
if v, ok = c.servers[id]; ok {
return v
}
if v, ok = c.listenSockets[id]; ok {
return v
}
if v, ok = c.normalSockets[id]; ok {
return v
}
return &dummyEntry{idNotFound: id}
}
// c.mu must be held by the caller
// deleteEntry simply deletes an entry from the channelMap. Before calling this
// method, caller must check this entry is ready to be deleted, i.e removeEntry()
// has been called on it, and no children still exist.
// Conditionals are ordered by the expected frequency of deletion of each entity
// type, in order to optimize performance.
func (c *channelMap) deleteEntry(id int64) {
var ok bool
if _, ok = c.normalSockets[id]; ok {
delete(c.normalSockets, id)
return
}
if _, ok = c.subChannels[id]; ok {
delete(c.subChannels, id)
return
}
if _, ok = c.channels[id]; ok {
delete(c.channels, id)
delete(c.topLevelChannels, id)
return
}
if _, ok = c.listenSockets[id]; ok {
delete(c.listenSockets, id)
return
}
if _, ok = c.servers[id]; ok {
delete(c.servers, id)
return
}
}
func (c *channelMap) traceEvent(id int64, desc *TraceEventDesc) {
c.mu.Lock()
child := c.findEntry(id)
childTC, ok := child.(tracedChannel)
if !ok {
c.mu.Unlock()
return
}
childTC.getChannelTrace().append(&TraceEvent{Desc: desc.Desc, Severity: desc.Severity, Timestamp: time.Now()})
if desc.Parent != nil {
parent := c.findEntry(child.getParentID())
var chanType RefChannelType
switch child.(type) {
case *channel:
chanType = RefChannel
case *subChannel:
chanType = RefSubChannel
}
if parentTC, ok := parent.(tracedChannel); ok {
parentTC.getChannelTrace().append(&TraceEvent{
Desc: desc.Parent.Desc,
Severity: desc.Parent.Severity,
Timestamp: time.Now(),
RefID: id,
RefName: childTC.getRefName(),
RefType: chanType,
})
childTC.incrTraceRefCount()
}
}
c.mu.Unlock()
}
type int64Slice []int64
func (s int64Slice) Len() int { return len(s) }
func (s int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s int64Slice) Less(i, j int) bool { return s[i] < s[j] }
func copyMap(m map[int64]string) map[int64]string {
n := make(map[int64]string)
for k, v := range m {
n[k] = v
}
return n
}
func min(a, b int64) int64 {
if a < b {
return a
}
return b
}
func (c *channelMap) GetTopChannels(id int64, maxResults int64) ([]*ChannelMetric, bool) {
if maxResults <= 0 {
maxResults = EntryPerPage
}
c.mu.RLock()
l := int64(len(c.topLevelChannels))
ids := make([]int64, 0, l)
cns := make([]*channel, 0, min(l, maxResults))
for k := range c.topLevelChannels {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
count := int64(0)
var end bool
var t []*ChannelMetric
for i, v := range ids[idx:] {
if count == maxResults {
break
}
if cn, ok := c.channels[v]; ok {
cns = append(cns, cn)
t = append(t, &ChannelMetric{
NestedChans: copyMap(cn.nestedChans),
SubChans: copyMap(cn.subChans),
})
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
for i, cn := range cns {
t[i].ChannelData = cn.c.ChannelzMetric()
t[i].ID = cn.id
t[i].RefName = cn.refName
t[i].Trace = cn.trace.dumpData()
}
return t, end
}
func (c *channelMap) GetServers(id, maxResults int64) ([]*ServerMetric, bool) {
if maxResults <= 0 {
maxResults = EntryPerPage
}
c.mu.RLock()
l := int64(len(c.servers))
ids := make([]int64, 0, l)
ss := make([]*server, 0, min(l, maxResults))
for k := range c.servers {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= id })
count := int64(0)
var end bool
var s []*ServerMetric
for i, v := range ids[idx:] {
if count == maxResults {
break
}
if svr, ok := c.servers[v]; ok {
ss = append(ss, svr)
s = append(s, &ServerMetric{
ListenSockets: copyMap(svr.listenSockets),
})
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
for i, svr := range ss {
s[i].ServerData = svr.s.ChannelzMetric()
s[i].ID = svr.id
s[i].RefName = svr.refName
}
return s, end
}
func (c *channelMap) GetServerSockets(id int64, startID int64, maxResults int64) ([]*SocketMetric, bool) {
if maxResults <= 0 {
maxResults = EntryPerPage
}
var svr *server
var ok bool
c.mu.RLock()
if svr, ok = c.servers[id]; !ok {
// server with id doesn't exist.
c.mu.RUnlock()
return nil, true
}
svrskts := svr.sockets
l := int64(len(svrskts))
ids := make([]int64, 0, l)
sks := make([]*normalSocket, 0, min(l, maxResults))
for k := range svrskts {
ids = append(ids, k)
}
sort.Sort(int64Slice(ids))
idx := sort.Search(len(ids), func(i int) bool { return ids[i] >= startID })
count := int64(0)
var end bool
for i, v := range ids[idx:] {
if count == maxResults {
break
}
if ns, ok := c.normalSockets[v]; ok {
sks = append(sks, ns)
count++
}
if i == len(ids[idx:])-1 {
end = true
break
}
}
c.mu.RUnlock()
if count == 0 {
end = true
}
s := make([]*SocketMetric, 0, len(sks))
for _, ns := range sks {
sm := &SocketMetric{}
sm.SocketData = ns.s.ChannelzMetric()
sm.ID = ns.id
sm.RefName = ns.refName
s = append(s, sm)
}
return s, end
}
func (c *channelMap) GetChannel(id int64) *ChannelMetric {
cm := &ChannelMetric{}
var cn *channel
var ok bool
c.mu.RLock()
if cn, ok = c.channels[id]; !ok {
// channel with id doesn't exist.
c.mu.RUnlock()
return nil
}
cm.NestedChans = copyMap(cn.nestedChans)
cm.SubChans = copyMap(cn.subChans)
// cn.c can be set to &dummyChannel{} when deleteSelfFromMap is called. Save a copy of cn.c when
// holding the lock to prevent potential data race.
chanCopy := cn.c
c.mu.RUnlock()
cm.ChannelData = chanCopy.ChannelzMetric()
cm.ID = cn.id
cm.RefName = cn.refName
cm.Trace = cn.trace.dumpData()
return cm
}
func (c *channelMap) GetSubChannel(id int64) *SubChannelMetric {
cm := &SubChannelMetric{}
var sc *subChannel
var ok bool
c.mu.RLock()
if sc, ok = c.subChannels[id]; !ok {
// subchannel with id doesn't exist.
c.mu.RUnlock()
return nil
}
cm.Sockets = copyMap(sc.sockets)
// sc.c can be set to &dummyChannel{} when deleteSelfFromMap is called. Save a copy of sc.c when
// holding the lock to prevent potential data race.
chanCopy := sc.c
c.mu.RUnlock()
cm.ChannelData = chanCopy.ChannelzMetric()
cm.ID = sc.id
cm.RefName = sc.refName
cm.Trace = sc.trace.dumpData()
return cm
}
func (c *channelMap) GetSocket(id int64) *SocketMetric {
sm := &SocketMetric{}
c.mu.RLock()
if ls, ok := c.listenSockets[id]; ok {
c.mu.RUnlock()
sm.SocketData = ls.s.ChannelzMetric()
sm.ID = ls.id
sm.RefName = ls.refName
return sm
}
if ns, ok := c.normalSockets[id]; ok {
c.mu.RUnlock()
sm.SocketData = ns.s.ChannelzMetric()
sm.ID = ns.id
sm.RefName = ns.refName
return sm
}
c.mu.RUnlock()
return nil
}
func (c *channelMap) GetServer(id int64) *ServerMetric {
sm := &ServerMetric{}
var svr *server
var ok bool
c.mu.RLock()
if svr, ok = c.servers[id]; !ok {
c.mu.RUnlock()
return nil
}
sm.ListenSockets = copyMap(svr.listenSockets)
c.mu.RUnlock()
sm.ID = svr.id
sm.RefName = svr.refName
sm.ServerData = svr.s.ChannelzMetric()
return sm
} }
// IDGenerator is an incrementing atomic that tracks IDs for channelz entities. // IDGenerator is an incrementing atomic that tracks IDs for channelz entities.
@ -761,3 +220,11 @@ func (i *IDGenerator) Reset() {
func (i *IDGenerator) genID() int64 { func (i *IDGenerator) genID() int64 {
return atomic.AddInt64(&i.id, 1) return atomic.AddInt64(&i.id, 1)
} }
// Identifier is an opaque channelz identifier used to expose channelz symbols
// outside of grpc. Currently only implemented by Channel since no other
// types require exposure outside grpc.
type Identifier interface {
Entity
channelzIdentifier()
}

View File

@ -1,75 +0,0 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import "fmt"
// Identifier is an opaque identifier which uniquely identifies an entity in the
// channelz database.
type Identifier struct {
typ RefChannelType
id int64
str string
pid *Identifier
}
// Type returns the entity type corresponding to id.
func (id *Identifier) Type() RefChannelType {
return id.typ
}
// Int returns the integer identifier corresponding to id.
func (id *Identifier) Int() int64 {
return id.id
}
// String returns a string representation of the entity corresponding to id.
//
// This includes some information about the parent as well. Examples:
// Top-level channel: [Channel #channel-number]
// Nested channel: [Channel #parent-channel-number Channel #channel-number]
// Sub channel: [Channel #parent-channel SubChannel #subchannel-number]
func (id *Identifier) String() string {
return id.str
}
// Equal returns true if other is the same as id.
func (id *Identifier) Equal(other *Identifier) bool {
if (id != nil) != (other != nil) {
return false
}
if id == nil && other == nil {
return true
}
return id.typ == other.typ && id.id == other.id && id.pid == other.pid
}
// NewIdentifierForTesting returns a new opaque identifier to be used only for
// testing purposes.
func NewIdentifierForTesting(typ RefChannelType, id int64, pid *Identifier) *Identifier {
return newIdentifer(typ, id, pid)
}
func newIdentifer(typ RefChannelType, id int64, pid *Identifier) *Identifier {
str := fmt.Sprintf("%s #%d", typ, id)
if pid != nil {
str = fmt.Sprintf("%s %s", pid, str)
}
return &Identifier{typ: typ, id: id, str: str, pid: pid}
}

View File

@ -26,53 +26,49 @@ import (
var logger = grpclog.Component("channelz") var logger = grpclog.Component("channelz")
func withParens(id *Identifier) string {
return "[" + id.String() + "] "
}
// Info logs and adds a trace event if channelz is on. // Info logs and adds a trace event if channelz is on.
func Info(l grpclog.DepthLoggerV2, id *Identifier, args ...any) { func Info(l grpclog.DepthLoggerV2, e Entity, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{ AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprint(args...), Desc: fmt.Sprint(args...),
Severity: CtInfo, Severity: CtInfo,
}) })
} }
// Infof logs and adds a trace event if channelz is on. // Infof logs and adds a trace event if channelz is on.
func Infof(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...any) { func Infof(l grpclog.DepthLoggerV2, e Entity, format string, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{ AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprintf(format, args...), Desc: fmt.Sprintf(format, args...),
Severity: CtInfo, Severity: CtInfo,
}) })
} }
// Warning logs and adds a trace event if channelz is on. // Warning logs and adds a trace event if channelz is on.
func Warning(l grpclog.DepthLoggerV2, id *Identifier, args ...any) { func Warning(l grpclog.DepthLoggerV2, e Entity, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{ AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprint(args...), Desc: fmt.Sprint(args...),
Severity: CtWarning, Severity: CtWarning,
}) })
} }
// Warningf logs and adds a trace event if channelz is on. // Warningf logs and adds a trace event if channelz is on.
func Warningf(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...any) { func Warningf(l grpclog.DepthLoggerV2, e Entity, format string, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{ AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprintf(format, args...), Desc: fmt.Sprintf(format, args...),
Severity: CtWarning, Severity: CtWarning,
}) })
} }
// Error logs and adds a trace event if channelz is on. // Error logs and adds a trace event if channelz is on.
func Error(l grpclog.DepthLoggerV2, id *Identifier, args ...any) { func Error(l grpclog.DepthLoggerV2, e Entity, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{ AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprint(args...), Desc: fmt.Sprint(args...),
Severity: CtError, Severity: CtError,
}) })
} }
// Errorf logs and adds a trace event if channelz is on. // Errorf logs and adds a trace event if channelz is on.
func Errorf(l grpclog.DepthLoggerV2, id *Identifier, format string, args ...any) { func Errorf(l grpclog.DepthLoggerV2, e Entity, format string, args ...any) {
AddTraceEvent(l, id, 1, &TraceEventDesc{ AddTraceEvent(l, e, 1, &TraceEvent{
Desc: fmt.Sprintf(format, args...), Desc: fmt.Sprintf(format, args...),
Severity: CtError, Severity: CtError,
}) })

View File

@ -0,0 +1,119 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync/atomic"
)
// Server is the channelz representation of a server.
type Server struct {
Entity
ID int64
RefName string
ServerMetrics ServerMetrics
closeCalled bool
sockets map[int64]string
listenSockets map[int64]string
cm *channelMap
}
// ServerMetrics defines a struct containing metrics for servers.
type ServerMetrics struct {
// The number of incoming calls started on the server.
CallsStarted atomic.Int64
// The number of incoming calls that have completed with an OK status.
CallsSucceeded atomic.Int64
// The number of incoming calls that have a completed with a non-OK status.
CallsFailed atomic.Int64
// The last time a call was started on the server.
LastCallStartedTimestamp atomic.Int64
}
// NewServerMetricsForTesting returns an initialized ServerMetrics.
func NewServerMetricsForTesting(started, succeeded, failed, timestamp int64) *ServerMetrics {
sm := &ServerMetrics{}
sm.CallsStarted.Store(started)
sm.CallsSucceeded.Store(succeeded)
sm.CallsFailed.Store(failed)
sm.LastCallStartedTimestamp.Store(timestamp)
return sm
}
func (sm *ServerMetrics) CopyFrom(o *ServerMetrics) {
sm.CallsStarted.Store(o.CallsStarted.Load())
sm.CallsSucceeded.Store(o.CallsSucceeded.Load())
sm.CallsFailed.Store(o.CallsFailed.Load())
sm.LastCallStartedTimestamp.Store(o.LastCallStartedTimestamp.Load())
}
// ListenSockets returns the listening sockets for s.
func (s *Server) ListenSockets() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(s.listenSockets)
}
// String returns a printable description of s.
func (s *Server) String() string {
return fmt.Sprintf("Server #%d", s.ID)
}
func (s *Server) id() int64 {
return s.ID
}
func (s *Server) addChild(id int64, e entry) {
switch v := e.(type) {
case *Socket:
switch v.SocketType {
case SocketTypeNormal:
s.sockets[id] = v.RefName
case SocketTypeListen:
s.listenSockets[id] = v.RefName
}
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a server", id, e)
}
}
func (s *Server) deleteChild(id int64) {
delete(s.sockets, id)
delete(s.listenSockets, id)
s.deleteSelfIfReady()
}
func (s *Server) triggerDelete() {
s.closeCalled = true
s.deleteSelfIfReady()
}
func (s *Server) deleteSelfIfReady() {
if !s.closeCalled || len(s.sockets)+len(s.listenSockets) != 0 {
return
}
s.cm.deleteEntry(s.ID)
}
func (s *Server) getParentID() int64 {
return 0
}

View File

@ -0,0 +1,130 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"net"
"sync/atomic"
"google.golang.org/grpc/credentials"
)
// SocketMetrics defines the struct that the implementor of Socket interface
// should return from ChannelzMetric().
type SocketMetrics struct {
// The number of streams that have been started.
StreamsStarted atomic.Int64
// The number of streams that have ended successfully:
// On client side, receiving frame with eos bit set.
// On server side, sending frame with eos bit set.
StreamsSucceeded atomic.Int64
// The number of streams that have ended unsuccessfully:
// On client side, termination without receiving frame with eos bit set.
// On server side, termination without sending frame with eos bit set.
StreamsFailed atomic.Int64
// The number of messages successfully sent on this socket.
MessagesSent atomic.Int64
MessagesReceived atomic.Int64
// The number of keep alives sent. This is typically implemented with HTTP/2
// ping messages.
KeepAlivesSent atomic.Int64
// The last time a stream was created by this endpoint. Usually unset for
// servers.
LastLocalStreamCreatedTimestamp atomic.Int64
// The last time a stream was created by the remote endpoint. Usually unset
// for clients.
LastRemoteStreamCreatedTimestamp atomic.Int64
// The last time a message was sent by this endpoint.
LastMessageSentTimestamp atomic.Int64
// The last time a message was received by this endpoint.
LastMessageReceivedTimestamp atomic.Int64
}
// EphemeralSocketMetrics are metrics that change rapidly and are tracked
// outside of channelz.
type EphemeralSocketMetrics struct {
// The amount of window, granted to the local endpoint by the remote endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
LocalFlowControlWindow int64
// The amount of window, granted to the remote endpoint by the local endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
RemoteFlowControlWindow int64
}
type SocketType string
const (
SocketTypeNormal = "NormalSocket"
SocketTypeListen = "ListenSocket"
)
type Socket struct {
Entity
SocketType SocketType
ID int64
Parent Entity
cm *channelMap
SocketMetrics SocketMetrics
EphemeralMetrics func() *EphemeralSocketMetrics
RefName string
// The locally bound address. Immutable.
LocalAddr net.Addr
// The remote bound address. May be absent. Immutable.
RemoteAddr net.Addr
// Optional, represents the name of the remote endpoint, if different than
// the original target name. Immutable.
RemoteName string
// Immutable.
SocketOptions *SocketOptionData
// Immutable.
Security credentials.ChannelzSecurityValue
}
func (ls *Socket) String() string {
return fmt.Sprintf("%s %s #%d", ls.Parent, ls.SocketType, ls.ID)
}
func (ls *Socket) id() int64 {
return ls.ID
}
func (ls *Socket) addChild(id int64, e entry) {
logger.Errorf("cannot add a child (id = %d) of type %T to a listen socket", id, e)
}
func (ls *Socket) deleteChild(id int64) {
logger.Errorf("cannot delete a child (id = %d) from a listen socket", id)
}
func (ls *Socket) triggerDelete() {
ls.cm.deleteEntry(ls.ID)
ls.Parent.(entry).deleteChild(ls.ID)
}
func (ls *Socket) deleteSelfIfReady() {
logger.Errorf("cannot call deleteSelfIfReady on a listen socket")
}
func (ls *Socket) getParentID() int64 {
return ls.Parent.id()
}

View File

@ -0,0 +1,151 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync/atomic"
)
// SubChannel is the channelz representation of a subchannel.
type SubChannel struct {
Entity
// ID is the channelz id of this subchannel.
ID int64
// RefName is the human readable reference string of this subchannel.
RefName string
closeCalled bool
sockets map[int64]string
parent *Channel
trace *ChannelTrace
traceRefCount int32
ChannelMetrics ChannelMetrics
}
func (sc *SubChannel) String() string {
return fmt.Sprintf("%s SubChannel #%d", sc.parent, sc.ID)
}
func (sc *SubChannel) id() int64 {
return sc.ID
}
func (sc *SubChannel) Sockets() map[int64]string {
db.mu.RLock()
defer db.mu.RUnlock()
return copyMap(sc.sockets)
}
func (sc *SubChannel) Trace() *ChannelTrace {
db.mu.RLock()
defer db.mu.RUnlock()
return sc.trace.copy()
}
func (sc *SubChannel) addChild(id int64, e entry) {
if v, ok := e.(*Socket); ok && v.SocketType == SocketTypeNormal {
sc.sockets[id] = v.RefName
} else {
logger.Errorf("cannot add a child (id = %d) of type %T to a subChannel", id, e)
}
}
func (sc *SubChannel) deleteChild(id int64) {
delete(sc.sockets, id)
sc.deleteSelfIfReady()
}
func (sc *SubChannel) triggerDelete() {
sc.closeCalled = true
sc.deleteSelfIfReady()
}
func (sc *SubChannel) getParentID() int64 {
return sc.parent.ID
}
// deleteSelfFromTree tries to delete the subchannel from the channelz entry relation tree, which
// means deleting the subchannel reference from its parent's child list.
//
// In order for a subchannel to be deleted from the tree, it must meet the criteria that, removal of
// the corresponding grpc object has been invoked, and the subchannel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (sc *SubChannel) deleteSelfFromTree() (deleted bool) {
if !sc.closeCalled || len(sc.sockets) != 0 {
return false
}
sc.parent.deleteChild(sc.ID)
return true
}
// deleteSelfFromMap checks whether it is valid to delete the subchannel from the map, which means
// deleting the subchannel from channelz's tracking entirely. Users can no longer use id to query
// the subchannel, and its memory will be garbage collected.
//
// The trace reference count of the subchannel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (sc *SubChannel) deleteSelfFromMap() (delete bool) {
return sc.getTraceRefCount() == 0
}
// deleteSelfIfReady tries to delete the subchannel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the subchannel from the entry relation tree, i.e. delete the subchannel reference from
// its parent's child list.
// 2. delete the subchannel from the map, i.e. delete the subchannel entirely from channelz. Lookup
// by id will return entry not found error.
func (sc *SubChannel) deleteSelfIfReady() {
if !sc.deleteSelfFromTree() {
return
}
if !sc.deleteSelfFromMap() {
return
}
db.deleteEntry(sc.ID)
sc.trace.clear()
}
func (sc *SubChannel) getChannelTrace() *ChannelTrace {
return sc.trace
}
func (sc *SubChannel) incrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, 1)
}
func (sc *SubChannel) decrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, -1)
}
func (sc *SubChannel) getTraceRefCount() int {
i := atomic.LoadInt32(&sc.traceRefCount)
return int(i)
}
func (sc *SubChannel) getRefName() string {
return sc.RefName
}

View File

@ -49,3 +49,17 @@ func (s *SocketOptionData) Getsockopt(fd uintptr) {
s.TCPInfo = v s.TCPInfo = v
} }
} }
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(socket any) *SocketOptionData {
c, ok := socket.(syscall.Conn)
if !ok {
return nil
}
data := &SocketOptionData{}
if rawConn, err := c.SyscallConn(); err == nil {
rawConn.Control(data.Getsockopt)
return data
}
return nil
}

View File

@ -1,5 +1,4 @@
//go:build !linux //go:build !linux
// +build !linux
/* /*
* *
@ -41,3 +40,8 @@ func (s *SocketOptionData) Getsockopt(fd uintptr) {
logger.Warning("Channelz: socket options are not supported on non-linux environments") logger.Warning("Channelz: socket options are not supported on non-linux environments")
}) })
} }
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(c any) *SocketOptionData {
return nil
}

View File

@ -0,0 +1,204 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"fmt"
"sync"
"sync/atomic"
"time"
"google.golang.org/grpc/grpclog"
)
const (
defaultMaxTraceEntry int32 = 30
)
var maxTraceEntry = defaultMaxTraceEntry
// SetMaxTraceEntry sets maximum number of trace entries per entity (i.e.
// channel/subchannel). Setting it to 0 will disable channel tracing.
func SetMaxTraceEntry(i int32) {
atomic.StoreInt32(&maxTraceEntry, i)
}
// ResetMaxTraceEntryToDefault resets the maximum number of trace entries per
// entity to default.
func ResetMaxTraceEntryToDefault() {
atomic.StoreInt32(&maxTraceEntry, defaultMaxTraceEntry)
}
func getMaxTraceEntry() int {
i := atomic.LoadInt32(&maxTraceEntry)
return int(i)
}
// traceEvent is an internal representation of a single trace event
type traceEvent struct {
// Desc is a simple description of the trace event.
Desc string
// Severity states the severity of this trace event.
Severity Severity
// Timestamp is the event time.
Timestamp time.Time
// RefID is the id of the entity that gets referenced in the event. RefID is 0 if no other entity is
// involved in this event.
// e.g. SubChannel (id: 4[]) Created. --> RefID = 4, RefName = "" (inside [])
RefID int64
// RefName is the reference name for the entity that gets referenced in the event.
RefName string
// RefType indicates the referenced entity type, i.e Channel or SubChannel.
RefType RefChannelType
}
// TraceEvent is what the caller of AddTraceEvent should provide to describe the
// event to be added to the channel trace.
//
// The Parent field is optional. It is used for an event that will be recorded
// in the entity's parent trace.
type TraceEvent struct {
Desc string
Severity Severity
Parent *TraceEvent
}
type ChannelTrace struct {
cm *channelMap
clearCalled bool
CreationTime time.Time
EventNum int64
mu sync.Mutex
Events []*traceEvent
}
func (c *ChannelTrace) copy() *ChannelTrace {
return &ChannelTrace{
CreationTime: c.CreationTime,
EventNum: c.EventNum,
Events: append(([]*traceEvent)(nil), c.Events...),
}
}
func (c *ChannelTrace) append(e *traceEvent) {
c.mu.Lock()
if len(c.Events) == getMaxTraceEntry() {
del := c.Events[0]
c.Events = c.Events[1:]
if del.RefID != 0 {
// start recursive cleanup in a goroutine to not block the call originated from grpc.
go func() {
// need to acquire c.cm.mu lock to call the unlocked attemptCleanup func.
c.cm.mu.Lock()
c.cm.decrTraceRefCount(del.RefID)
c.cm.mu.Unlock()
}()
}
}
e.Timestamp = time.Now()
c.Events = append(c.Events, e)
c.EventNum++
c.mu.Unlock()
}
func (c *ChannelTrace) clear() {
if c.clearCalled {
return
}
c.clearCalled = true
c.mu.Lock()
for _, e := range c.Events {
if e.RefID != 0 {
// caller should have already held the c.cm.mu lock.
c.cm.decrTraceRefCount(e.RefID)
}
}
c.mu.Unlock()
}
// Severity is the severity level of a trace event.
// The canonical enumeration of all valid values is here:
// https://github.com/grpc/grpc-proto/blob/9b13d199cc0d4703c7ea26c9c330ba695866eb23/grpc/channelz/v1/channelz.proto#L126.
type Severity int
const (
// CtUnknown indicates unknown severity of a trace event.
CtUnknown Severity = iota
// CtInfo indicates info level severity of a trace event.
CtInfo
// CtWarning indicates warning level severity of a trace event.
CtWarning
// CtError indicates error level severity of a trace event.
CtError
)
// RefChannelType is the type of the entity being referenced in a trace event.
type RefChannelType int
const (
// RefUnknown indicates an unknown entity type, the zero value for this type.
RefUnknown RefChannelType = iota
// RefChannel indicates the referenced entity is a Channel.
RefChannel
// RefSubChannel indicates the referenced entity is a SubChannel.
RefSubChannel
// RefServer indicates the referenced entity is a Server.
RefServer
// RefListenSocket indicates the referenced entity is a ListenSocket.
RefListenSocket
// RefNormalSocket indicates the referenced entity is a NormalSocket.
RefNormalSocket
)
var refChannelTypeToString = map[RefChannelType]string{
RefUnknown: "Unknown",
RefChannel: "Channel",
RefSubChannel: "SubChannel",
RefServer: "Server",
RefListenSocket: "ListenSocket",
RefNormalSocket: "NormalSocket",
}
func (r RefChannelType) String() string {
return refChannelTypeToString[r]
}
// AddTraceEvent adds trace related to the entity with specified id, using the
// provided TraceEventDesc.
//
// If channelz is not turned ON, this will simply log the event descriptions.
func AddTraceEvent(l grpclog.DepthLoggerV2, e Entity, depth int, desc *TraceEvent) {
// Log only the trace description associated with the bottom most entity.
d := fmt.Sprintf("[%s]%s", e, desc.Desc)
switch desc.Severity {
case CtUnknown, CtInfo:
l.InfoDepth(depth+1, d)
case CtWarning:
l.WarningDepth(depth+1, d)
case CtError:
l.ErrorDepth(depth+1, d)
}
if getMaxTraceEntry() == 0 {
return
}
if IsOn() {
db.traceEvent(e.id(), desc)
}
}

View File

@ -1,727 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"net"
"sync"
"sync/atomic"
"time"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
)
// entry represents a node in the channelz database.
type entry interface {
// addChild adds a child e, whose channelz id is id to child list
addChild(id int64, e entry)
// deleteChild deletes a child with channelz id to be id from child list
deleteChild(id int64)
// triggerDelete tries to delete self from channelz database. However, if child
// list is not empty, then deletion from the database is on hold until the last
// child is deleted from database.
triggerDelete()
// deleteSelfIfReady check whether triggerDelete() has been called before, and whether child
// list is now empty. If both conditions are met, then delete self from database.
deleteSelfIfReady()
// getParentID returns parent ID of the entry. 0 value parent ID means no parent.
getParentID() int64
}
// dummyEntry is a fake entry to handle entry not found case.
type dummyEntry struct {
idNotFound int64
}
func (d *dummyEntry) addChild(id int64, e entry) {
// Note: It is possible for a normal program to reach here under race condition.
// For example, there could be a race between ClientConn.Close() info being propagated
// to addrConn and http2Client. ClientConn.Close() cancel the context and result
// in http2Client to error. The error info is then caught by transport monitor
// and before addrConn.tearDown() is called in side ClientConn.Close(). Therefore,
// the addrConn will create a new transport. And when registering the new transport in
// channelz, its parent addrConn could have already been torn down and deleted
// from channelz tracking, and thus reach the code here.
logger.Infof("attempt to add child of type %T with id %d to a parent (id=%d) that doesn't currently exist", e, id, d.idNotFound)
}
func (d *dummyEntry) deleteChild(id int64) {
// It is possible for a normal program to reach here under race condition.
// Refer to the example described in addChild().
logger.Infof("attempt to delete child with id %d from a parent (id=%d) that doesn't currently exist", id, d.idNotFound)
}
func (d *dummyEntry) triggerDelete() {
logger.Warningf("attempt to delete an entry (id=%d) that doesn't currently exist", d.idNotFound)
}
func (*dummyEntry) deleteSelfIfReady() {
// code should not reach here. deleteSelfIfReady is always called on an existing entry.
}
func (*dummyEntry) getParentID() int64 {
return 0
}
// ChannelMetric defines the info channelz provides for a specific Channel, which
// includes ChannelInternalMetric and channelz-specific data, such as channelz id,
// child list, etc.
type ChannelMetric struct {
// ID is the channelz id of this channel.
ID int64
// RefName is the human readable reference string of this channel.
RefName string
// ChannelData contains channel internal metric reported by the channel through
// ChannelzMetric().
ChannelData *ChannelInternalMetric
// NestedChans tracks the nested channel type children of this channel in the format of
// a map from nested channel channelz id to corresponding reference string.
NestedChans map[int64]string
// SubChans tracks the subchannel type children of this channel in the format of a
// map from subchannel channelz id to corresponding reference string.
SubChans map[int64]string
// Sockets tracks the socket type children of this channel in the format of a map
// from socket channelz id to corresponding reference string.
// Note current grpc implementation doesn't allow channel having sockets directly,
// therefore, this is field is unused.
Sockets map[int64]string
// Trace contains the most recent traced events.
Trace *ChannelTrace
}
// SubChannelMetric defines the info channelz provides for a specific SubChannel,
// which includes ChannelInternalMetric and channelz-specific data, such as
// channelz id, child list, etc.
type SubChannelMetric struct {
// ID is the channelz id of this subchannel.
ID int64
// RefName is the human readable reference string of this subchannel.
RefName string
// ChannelData contains subchannel internal metric reported by the subchannel
// through ChannelzMetric().
ChannelData *ChannelInternalMetric
// NestedChans tracks the nested channel type children of this subchannel in the format of
// a map from nested channel channelz id to corresponding reference string.
// Note current grpc implementation doesn't allow subchannel to have nested channels
// as children, therefore, this field is unused.
NestedChans map[int64]string
// SubChans tracks the subchannel type children of this subchannel in the format of a
// map from subchannel channelz id to corresponding reference string.
// Note current grpc implementation doesn't allow subchannel to have subchannels
// as children, therefore, this field is unused.
SubChans map[int64]string
// Sockets tracks the socket type children of this subchannel in the format of a map
// from socket channelz id to corresponding reference string.
Sockets map[int64]string
// Trace contains the most recent traced events.
Trace *ChannelTrace
}
// ChannelInternalMetric defines the struct that the implementor of Channel interface
// should return from ChannelzMetric().
type ChannelInternalMetric struct {
// current connectivity state of the channel.
State connectivity.State
// The target this channel originally tried to connect to. May be absent
Target string
// The number of calls started on the channel.
CallsStarted int64
// The number of calls that have completed with an OK status.
CallsSucceeded int64
// The number of calls that have a completed with a non-OK status.
CallsFailed int64
// The last time a call was started on the channel.
LastCallStartedTimestamp time.Time
}
// ChannelTrace stores traced events on a channel/subchannel and related info.
type ChannelTrace struct {
// EventNum is the number of events that ever got traced (i.e. including those that have been deleted)
EventNum int64
// CreationTime is the creation time of the trace.
CreationTime time.Time
// Events stores the most recent trace events (up to $maxTraceEntry, newer event will overwrite the
// oldest one)
Events []*TraceEvent
}
// TraceEvent represent a single trace event
type TraceEvent struct {
// Desc is a simple description of the trace event.
Desc string
// Severity states the severity of this trace event.
Severity Severity
// Timestamp is the event time.
Timestamp time.Time
// RefID is the id of the entity that gets referenced in the event. RefID is 0 if no other entity is
// involved in this event.
// e.g. SubChannel (id: 4[]) Created. --> RefID = 4, RefName = "" (inside [])
RefID int64
// RefName is the reference name for the entity that gets referenced in the event.
RefName string
// RefType indicates the referenced entity type, i.e Channel or SubChannel.
RefType RefChannelType
}
// Channel is the interface that should be satisfied in order to be tracked by
// channelz as Channel or SubChannel.
type Channel interface {
ChannelzMetric() *ChannelInternalMetric
}
type dummyChannel struct{}
func (d *dummyChannel) ChannelzMetric() *ChannelInternalMetric {
return &ChannelInternalMetric{}
}
type channel struct {
refName string
c Channel
closeCalled bool
nestedChans map[int64]string
subChans map[int64]string
id int64
pid int64
cm *channelMap
trace *channelTrace
// traceRefCount is the number of trace events that reference this channel.
// Non-zero traceRefCount means the trace of this channel cannot be deleted.
traceRefCount int32
}
func (c *channel) addChild(id int64, e entry) {
switch v := e.(type) {
case *subChannel:
c.subChans[id] = v.refName
case *channel:
c.nestedChans[id] = v.refName
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a channel", id, e)
}
}
func (c *channel) deleteChild(id int64) {
delete(c.subChans, id)
delete(c.nestedChans, id)
c.deleteSelfIfReady()
}
func (c *channel) triggerDelete() {
c.closeCalled = true
c.deleteSelfIfReady()
}
func (c *channel) getParentID() int64 {
return c.pid
}
// deleteSelfFromTree tries to delete the channel from the channelz entry relation tree, which means
// deleting the channel reference from its parent's child list.
//
// In order for a channel to be deleted from the tree, it must meet the criteria that, removal of the
// corresponding grpc object has been invoked, and the channel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (c *channel) deleteSelfFromTree() (deleted bool) {
if !c.closeCalled || len(c.subChans)+len(c.nestedChans) != 0 {
return false
}
// not top channel
if c.pid != 0 {
c.cm.findEntry(c.pid).deleteChild(c.id)
}
return true
}
// deleteSelfFromMap checks whether it is valid to delete the channel from the map, which means
// deleting the channel from channelz's tracking entirely. Users can no longer use id to query the
// channel, and its memory will be garbage collected.
//
// The trace reference count of the channel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (c *channel) deleteSelfFromMap() (delete bool) {
if c.getTraceRefCount() != 0 {
c.c = &dummyChannel{}
return false
}
return true
}
// deleteSelfIfReady tries to delete the channel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the channel from the entry relation tree, i.e. delete the channel reference from its
// parent's child list.
// 2. delete the channel from the map, i.e. delete the channel entirely from channelz. Lookup by id
// will return entry not found error.
func (c *channel) deleteSelfIfReady() {
if !c.deleteSelfFromTree() {
return
}
if !c.deleteSelfFromMap() {
return
}
c.cm.deleteEntry(c.id)
c.trace.clear()
}
func (c *channel) getChannelTrace() *channelTrace {
return c.trace
}
func (c *channel) incrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, 1)
}
func (c *channel) decrTraceRefCount() {
atomic.AddInt32(&c.traceRefCount, -1)
}
func (c *channel) getTraceRefCount() int {
i := atomic.LoadInt32(&c.traceRefCount)
return int(i)
}
func (c *channel) getRefName() string {
return c.refName
}
type subChannel struct {
refName string
c Channel
closeCalled bool
sockets map[int64]string
id int64
pid int64
cm *channelMap
trace *channelTrace
traceRefCount int32
}
func (sc *subChannel) addChild(id int64, e entry) {
if v, ok := e.(*normalSocket); ok {
sc.sockets[id] = v.refName
} else {
logger.Errorf("cannot add a child (id = %d) of type %T to a subChannel", id, e)
}
}
func (sc *subChannel) deleteChild(id int64) {
delete(sc.sockets, id)
sc.deleteSelfIfReady()
}
func (sc *subChannel) triggerDelete() {
sc.closeCalled = true
sc.deleteSelfIfReady()
}
func (sc *subChannel) getParentID() int64 {
return sc.pid
}
// deleteSelfFromTree tries to delete the subchannel from the channelz entry relation tree, which
// means deleting the subchannel reference from its parent's child list.
//
// In order for a subchannel to be deleted from the tree, it must meet the criteria that, removal of
// the corresponding grpc object has been invoked, and the subchannel does not have any children left.
//
// The returned boolean value indicates whether the channel has been successfully deleted from tree.
func (sc *subChannel) deleteSelfFromTree() (deleted bool) {
if !sc.closeCalled || len(sc.sockets) != 0 {
return false
}
sc.cm.findEntry(sc.pid).deleteChild(sc.id)
return true
}
// deleteSelfFromMap checks whether it is valid to delete the subchannel from the map, which means
// deleting the subchannel from channelz's tracking entirely. Users can no longer use id to query
// the subchannel, and its memory will be garbage collected.
//
// The trace reference count of the subchannel must be 0 in order to be deleted from the map. This is
// specified in the channel tracing gRFC that as long as some other trace has reference to an entity,
// the trace of the referenced entity must not be deleted. In order to release the resource allocated
// by grpc, the reference to the grpc object is reset to a dummy object.
//
// deleteSelfFromMap must be called after deleteSelfFromTree returns true.
//
// It returns a bool to indicate whether the channel can be safely deleted from map.
func (sc *subChannel) deleteSelfFromMap() (delete bool) {
if sc.getTraceRefCount() != 0 {
// free the grpc struct (i.e. addrConn)
sc.c = &dummyChannel{}
return false
}
return true
}
// deleteSelfIfReady tries to delete the subchannel itself from the channelz database.
// The delete process includes two steps:
// 1. delete the subchannel from the entry relation tree, i.e. delete the subchannel reference from
// its parent's child list.
// 2. delete the subchannel from the map, i.e. delete the subchannel entirely from channelz. Lookup
// by id will return entry not found error.
func (sc *subChannel) deleteSelfIfReady() {
if !sc.deleteSelfFromTree() {
return
}
if !sc.deleteSelfFromMap() {
return
}
sc.cm.deleteEntry(sc.id)
sc.trace.clear()
}
func (sc *subChannel) getChannelTrace() *channelTrace {
return sc.trace
}
func (sc *subChannel) incrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, 1)
}
func (sc *subChannel) decrTraceRefCount() {
atomic.AddInt32(&sc.traceRefCount, -1)
}
func (sc *subChannel) getTraceRefCount() int {
i := atomic.LoadInt32(&sc.traceRefCount)
return int(i)
}
func (sc *subChannel) getRefName() string {
return sc.refName
}
// SocketMetric defines the info channelz provides for a specific Socket, which
// includes SocketInternalMetric and channelz-specific data, such as channelz id, etc.
type SocketMetric struct {
// ID is the channelz id of this socket.
ID int64
// RefName is the human readable reference string of this socket.
RefName string
// SocketData contains socket internal metric reported by the socket through
// ChannelzMetric().
SocketData *SocketInternalMetric
}
// SocketInternalMetric defines the struct that the implementor of Socket interface
// should return from ChannelzMetric().
type SocketInternalMetric struct {
// The number of streams that have been started.
StreamsStarted int64
// The number of streams that have ended successfully:
// On client side, receiving frame with eos bit set.
// On server side, sending frame with eos bit set.
StreamsSucceeded int64
// The number of streams that have ended unsuccessfully:
// On client side, termination without receiving frame with eos bit set.
// On server side, termination without sending frame with eos bit set.
StreamsFailed int64
// The number of messages successfully sent on this socket.
MessagesSent int64
MessagesReceived int64
// The number of keep alives sent. This is typically implemented with HTTP/2
// ping messages.
KeepAlivesSent int64
// The last time a stream was created by this endpoint. Usually unset for
// servers.
LastLocalStreamCreatedTimestamp time.Time
// The last time a stream was created by the remote endpoint. Usually unset
// for clients.
LastRemoteStreamCreatedTimestamp time.Time
// The last time a message was sent by this endpoint.
LastMessageSentTimestamp time.Time
// The last time a message was received by this endpoint.
LastMessageReceivedTimestamp time.Time
// The amount of window, granted to the local endpoint by the remote endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
LocalFlowControlWindow int64
// The amount of window, granted to the remote endpoint by the local endpoint.
// This may be slightly out of date due to network latency. This does NOT
// include stream level or TCP level flow control info.
RemoteFlowControlWindow int64
// The locally bound address.
LocalAddr net.Addr
// The remote bound address. May be absent.
RemoteAddr net.Addr
// Optional, represents the name of the remote endpoint, if different than
// the original target name.
RemoteName string
SocketOptions *SocketOptionData
Security credentials.ChannelzSecurityValue
}
// Socket is the interface that should be satisfied in order to be tracked by
// channelz as Socket.
type Socket interface {
ChannelzMetric() *SocketInternalMetric
}
type listenSocket struct {
refName string
s Socket
id int64
pid int64
cm *channelMap
}
func (ls *listenSocket) addChild(id int64, e entry) {
logger.Errorf("cannot add a child (id = %d) of type %T to a listen socket", id, e)
}
func (ls *listenSocket) deleteChild(id int64) {
logger.Errorf("cannot delete a child (id = %d) from a listen socket", id)
}
func (ls *listenSocket) triggerDelete() {
ls.cm.deleteEntry(ls.id)
ls.cm.findEntry(ls.pid).deleteChild(ls.id)
}
func (ls *listenSocket) deleteSelfIfReady() {
logger.Errorf("cannot call deleteSelfIfReady on a listen socket")
}
func (ls *listenSocket) getParentID() int64 {
return ls.pid
}
type normalSocket struct {
refName string
s Socket
id int64
pid int64
cm *channelMap
}
func (ns *normalSocket) addChild(id int64, e entry) {
logger.Errorf("cannot add a child (id = %d) of type %T to a normal socket", id, e)
}
func (ns *normalSocket) deleteChild(id int64) {
logger.Errorf("cannot delete a child (id = %d) from a normal socket", id)
}
func (ns *normalSocket) triggerDelete() {
ns.cm.deleteEntry(ns.id)
ns.cm.findEntry(ns.pid).deleteChild(ns.id)
}
func (ns *normalSocket) deleteSelfIfReady() {
logger.Errorf("cannot call deleteSelfIfReady on a normal socket")
}
func (ns *normalSocket) getParentID() int64 {
return ns.pid
}
// ServerMetric defines the info channelz provides for a specific Server, which
// includes ServerInternalMetric and channelz-specific data, such as channelz id,
// child list, etc.
type ServerMetric struct {
// ID is the channelz id of this server.
ID int64
// RefName is the human readable reference string of this server.
RefName string
// ServerData contains server internal metric reported by the server through
// ChannelzMetric().
ServerData *ServerInternalMetric
// ListenSockets tracks the listener socket type children of this server in the
// format of a map from socket channelz id to corresponding reference string.
ListenSockets map[int64]string
}
// ServerInternalMetric defines the struct that the implementor of Server interface
// should return from ChannelzMetric().
type ServerInternalMetric struct {
// The number of incoming calls started on the server.
CallsStarted int64
// The number of incoming calls that have completed with an OK status.
CallsSucceeded int64
// The number of incoming calls that have a completed with a non-OK status.
CallsFailed int64
// The last time a call was started on the server.
LastCallStartedTimestamp time.Time
}
// Server is the interface to be satisfied in order to be tracked by channelz as
// Server.
type Server interface {
ChannelzMetric() *ServerInternalMetric
}
type server struct {
refName string
s Server
closeCalled bool
sockets map[int64]string
listenSockets map[int64]string
id int64
cm *channelMap
}
func (s *server) addChild(id int64, e entry) {
switch v := e.(type) {
case *normalSocket:
s.sockets[id] = v.refName
case *listenSocket:
s.listenSockets[id] = v.refName
default:
logger.Errorf("cannot add a child (id = %d) of type %T to a server", id, e)
}
}
func (s *server) deleteChild(id int64) {
delete(s.sockets, id)
delete(s.listenSockets, id)
s.deleteSelfIfReady()
}
func (s *server) triggerDelete() {
s.closeCalled = true
s.deleteSelfIfReady()
}
func (s *server) deleteSelfIfReady() {
if !s.closeCalled || len(s.sockets)+len(s.listenSockets) != 0 {
return
}
s.cm.deleteEntry(s.id)
}
func (s *server) getParentID() int64 {
return 0
}
type tracedChannel interface {
getChannelTrace() *channelTrace
incrTraceRefCount()
decrTraceRefCount()
getRefName() string
}
type channelTrace struct {
cm *channelMap
clearCalled bool
createdTime time.Time
eventCount int64
mu sync.Mutex
events []*TraceEvent
}
func (c *channelTrace) append(e *TraceEvent) {
c.mu.Lock()
if len(c.events) == getMaxTraceEntry() {
del := c.events[0]
c.events = c.events[1:]
if del.RefID != 0 {
// start recursive cleanup in a goroutine to not block the call originated from grpc.
go func() {
// need to acquire c.cm.mu lock to call the unlocked attemptCleanup func.
c.cm.mu.Lock()
c.cm.decrTraceRefCount(del.RefID)
c.cm.mu.Unlock()
}()
}
}
e.Timestamp = time.Now()
c.events = append(c.events, e)
c.eventCount++
c.mu.Unlock()
}
func (c *channelTrace) clear() {
if c.clearCalled {
return
}
c.clearCalled = true
c.mu.Lock()
for _, e := range c.events {
if e.RefID != 0 {
// caller should have already held the c.cm.mu lock.
c.cm.decrTraceRefCount(e.RefID)
}
}
c.mu.Unlock()
}
// Severity is the severity level of a trace event.
// The canonical enumeration of all valid values is here:
// https://github.com/grpc/grpc-proto/blob/9b13d199cc0d4703c7ea26c9c330ba695866eb23/grpc/channelz/v1/channelz.proto#L126.
type Severity int
const (
// CtUnknown indicates unknown severity of a trace event.
CtUnknown Severity = iota
// CtInfo indicates info level severity of a trace event.
CtInfo
// CtWarning indicates warning level severity of a trace event.
CtWarning
// CtError indicates error level severity of a trace event.
CtError
)
// RefChannelType is the type of the entity being referenced in a trace event.
type RefChannelType int
const (
// RefUnknown indicates an unknown entity type, the zero value for this type.
RefUnknown RefChannelType = iota
// RefChannel indicates the referenced entity is a Channel.
RefChannel
// RefSubChannel indicates the referenced entity is a SubChannel.
RefSubChannel
// RefServer indicates the referenced entity is a Server.
RefServer
// RefListenSocket indicates the referenced entity is a ListenSocket.
RefListenSocket
// RefNormalSocket indicates the referenced entity is a NormalSocket.
RefNormalSocket
)
var refChannelTypeToString = map[RefChannelType]string{
RefUnknown: "Unknown",
RefChannel: "Channel",
RefSubChannel: "SubChannel",
RefServer: "Server",
RefListenSocket: "ListenSocket",
RefNormalSocket: "NormalSocket",
}
func (r RefChannelType) String() string {
return refChannelTypeToString[r]
}
func (c *channelTrace) dumpData() *ChannelTrace {
c.mu.Lock()
ct := &ChannelTrace{EventNum: c.eventCount, CreationTime: c.createdTime}
ct.Events = c.events[:len(c.events)]
c.mu.Unlock()
return ct
}

View File

@ -1,37 +0,0 @@
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package channelz
import (
"syscall"
)
// GetSocketOption gets the socket option info of the conn.
func GetSocketOption(socket any) *SocketOptionData {
c, ok := socket.(syscall.Conn)
if !ok {
return nil
}
data := &SocketOptionData{}
if rawConn, err := c.SyscallConn(); err == nil {
rawConn.Control(data.Getsockopt)
return data
}
return nil
}

View File

@ -28,9 +28,6 @@ import (
var ( var (
// TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false"). // TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false").
TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true) TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true)
// AdvertiseCompressors is set if registered compressor should be advertised
// ("GRPC_GO_ADVERTISE_COMPRESSORS" is not "false").
AdvertiseCompressors = boolFromEnv("GRPC_GO_ADVERTISE_COMPRESSORS", true)
// RingHashCap indicates the maximum ring size which defaults to 4096 // RingHashCap indicates the maximum ring size which defaults to 4096
// entries but may be overridden by setting the environment variable // entries but may be overridden by setting the environment variable
// "GRPC_RING_HASH_CAP". This does not override the default bounds // "GRPC_RING_HASH_CAP". This does not override the default bounds
@ -43,6 +40,16 @@ var (
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS // ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed. // handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100) ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
// XDSFallbackSupport is the env variable that controls whether support for
// xDS fallback is turned on. If this is unset or is false, only the first
// xDS server in the list of server configs will be used.
XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", false)
) )
func boolFromEnv(envVar string, def bool) bool { func boolFromEnv(envVar string, def bool) bool {

View File

@ -18,11 +18,11 @@
package internal package internal
var ( var (
// WithRecvBufferPool is implemented by the grpc package and returns a dial // WithBufferPool is implemented by the grpc package and returns a dial
// option to configure a shared buffer pool for a grpc.ClientConn. // option to configure a shared buffer pool for a grpc.ClientConn.
WithRecvBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption WithBufferPool any // func (grpc.SharedBufferPool) grpc.DialOption
// RecvBufferPool is implemented by the grpc package and returns a server // BufferPool is implemented by the grpc package and returns a server
// option to configure a shared buffer pool for a grpc.Server. // option to configure a shared buffer pool for a grpc.Server.
RecvBufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption
) )

View File

@ -16,17 +16,21 @@
* *
*/ */
// Package grpclog provides logging functionality for internal gRPC packages,
// outside of the functionality provided by the external `grpclog` package.
package grpclog package grpclog
import ( import (
"fmt" "fmt"
"google.golang.org/grpc/grpclog"
) )
// PrefixLogger does logging with a prefix. // PrefixLogger does logging with a prefix.
// //
// Logging method on a nil logs without any prefix. // Logging method on a nil logs without any prefix.
type PrefixLogger struct { type PrefixLogger struct {
logger DepthLoggerV2 logger grpclog.DepthLoggerV2
prefix string prefix string
} }
@ -38,7 +42,7 @@ func (pl *PrefixLogger) Infof(format string, args ...any) {
pl.logger.InfoDepth(1, fmt.Sprintf(format, args...)) pl.logger.InfoDepth(1, fmt.Sprintf(format, args...))
return return
} }
InfoDepth(1, fmt.Sprintf(format, args...)) grpclog.InfoDepth(1, fmt.Sprintf(format, args...))
} }
// Warningf does warning logging. // Warningf does warning logging.
@ -48,7 +52,7 @@ func (pl *PrefixLogger) Warningf(format string, args ...any) {
pl.logger.WarningDepth(1, fmt.Sprintf(format, args...)) pl.logger.WarningDepth(1, fmt.Sprintf(format, args...))
return return
} }
WarningDepth(1, fmt.Sprintf(format, args...)) grpclog.WarningDepth(1, fmt.Sprintf(format, args...))
} }
// Errorf does error logging. // Errorf does error logging.
@ -58,36 +62,18 @@ func (pl *PrefixLogger) Errorf(format string, args ...any) {
pl.logger.ErrorDepth(1, fmt.Sprintf(format, args...)) pl.logger.ErrorDepth(1, fmt.Sprintf(format, args...))
return return
} }
ErrorDepth(1, fmt.Sprintf(format, args...)) grpclog.ErrorDepth(1, fmt.Sprintf(format, args...))
}
// Debugf does info logging at verbose level 2.
func (pl *PrefixLogger) Debugf(format string, args ...any) {
// TODO(6044): Refactor interfaces LoggerV2 and DepthLogger, and maybe
// rewrite PrefixLogger a little to ensure that we don't use the global
// `Logger` here, and instead use the `logger` field.
if !Logger.V(2) {
return
}
if pl != nil {
// Handle nil, so the tests can pass in a nil logger.
format = pl.prefix + format
pl.logger.InfoDepth(1, fmt.Sprintf(format, args...))
return
}
InfoDepth(1, fmt.Sprintf(format, args...))
} }
// V reports whether verbosity level l is at least the requested verbose level. // V reports whether verbosity level l is at least the requested verbose level.
func (pl *PrefixLogger) V(l int) bool { func (pl *PrefixLogger) V(l int) bool {
// TODO(6044): Refactor interfaces LoggerV2 and DepthLogger, and maybe if pl != nil {
// rewrite PrefixLogger a little to ensure that we don't use the global return pl.logger.V(l)
// `Logger` here, and instead use the `logger` field. }
return Logger.V(l) return true
} }
// NewPrefixLogger creates a prefix logger with the given prefix. // NewPrefixLogger creates a prefix logger with the given prefix.
func NewPrefixLogger(logger DepthLoggerV2, prefix string) *PrefixLogger { func NewPrefixLogger(logger grpclog.DepthLoggerV2, prefix string) *PrefixLogger {
return &PrefixLogger{logger: logger, prefix: prefix} return &PrefixLogger{logger: logger, prefix: prefix}
} }

View File

@ -1,100 +0,0 @@
//go:build !go1.21
// TODO: when this file is deleted (after Go 1.20 support is dropped), delete
// all of grpcrand and call the rand package directly.
/*
*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import (
"math/rand"
"sync"
"time"
)
var (
r = rand.New(rand.NewSource(time.Now().UnixNano()))
mu sync.Mutex
)
// Int implements rand.Int on the grpcrand global source.
func Int() int {
mu.Lock()
defer mu.Unlock()
return r.Int()
}
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
mu.Lock()
defer mu.Unlock()
return r.Int63n(n)
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
mu.Lock()
defer mu.Unlock()
return r.Intn(n)
}
// Int31n implements rand.Int31n on the grpcrand global source.
func Int31n(n int32) int32 {
mu.Lock()
defer mu.Unlock()
return r.Int31n(n)
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
mu.Lock()
defer mu.Unlock()
return r.Float64()
}
// Uint64 implements rand.Uint64 on the grpcrand global source.
func Uint64() uint64 {
mu.Lock()
defer mu.Unlock()
return r.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
mu.Lock()
defer mu.Unlock()
return r.Uint32()
}
// ExpFloat64 implements rand.ExpFloat64 on the grpcrand global source.
func ExpFloat64() float64 {
mu.Lock()
defer mu.Unlock()
return r.ExpFloat64()
}
// Shuffle implements rand.Shuffle on the grpcrand global source.
var Shuffle = func(n int, f func(int, int)) {
mu.Lock()
defer mu.Unlock()
r.Shuffle(n, f)
}

View File

@ -1,73 +0,0 @@
//go:build go1.21
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package grpcrand implements math/rand functions in a concurrent-safe way
// with a global random source, independent of math/rand's global source.
package grpcrand
import "math/rand"
// This implementation will be used for Go version 1.21 or newer.
// For older versions, the original implementation with mutex will be used.
// Int implements rand.Int on the grpcrand global source.
func Int() int {
return rand.Int()
}
// Int63n implements rand.Int63n on the grpcrand global source.
func Int63n(n int64) int64 {
return rand.Int63n(n)
}
// Intn implements rand.Intn on the grpcrand global source.
func Intn(n int) int {
return rand.Intn(n)
}
// Int31n implements rand.Int31n on the grpcrand global source.
func Int31n(n int32) int32 {
return rand.Int31n(n)
}
// Float64 implements rand.Float64 on the grpcrand global source.
func Float64() float64 {
return rand.Float64()
}
// Uint64 implements rand.Uint64 on the grpcrand global source.
func Uint64() uint64 {
return rand.Uint64()
}
// Uint32 implements rand.Uint32 on the grpcrand global source.
func Uint32() uint32 {
return rand.Uint32()
}
// ExpFloat64 implements rand.ExpFloat64 on the grpcrand global source.
func ExpFloat64() float64 {
return rand.ExpFloat64()
}
// Shuffle implements rand.Shuffle on the grpcrand global source.
var Shuffle = func(n int, f func(int, int)) {
rand.Shuffle(n, f)
}

View File

@ -53,16 +53,28 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
return cs return cs
} }
// Schedule adds a callback to be scheduled after existing callbacks are run. // TrySchedule tries to schedules the provided callback function f to be
// executed in the order it was added. This is a best-effort operation. If the
// context passed to NewCallbackSerializer was canceled before this method is
// called, the callback will not be scheduled.
// //
// Callbacks are expected to honor the context when performing any blocking // Callbacks are expected to honor the context when performing any blocking
// operations, and should return early when the context is canceled. // operations, and should return early when the context is canceled.
func (cs *CallbackSerializer) TrySchedule(f func(ctx context.Context)) {
cs.callbacks.Put(f)
}
// ScheduleOr schedules the provided callback function f to be executed in the
// order it was added. If the context passed to NewCallbackSerializer has been
// canceled before this method is called, the onFailure callback will be
// executed inline instead.
// //
// Return value indicates if the callback was successfully added to the list of // Callbacks are expected to honor the context when performing any blocking
// callbacks to be executed by the serializer. It is not possible to add // operations, and should return early when the context is canceled.
// callbacks once the context passed to NewCallbackSerializer is cancelled. func (cs *CallbackSerializer) ScheduleOr(f func(ctx context.Context), onFailure func()) {
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool { if cs.callbacks.Put(f) != nil {
return cs.callbacks.Put(f) == nil onFailure()
}
} }
func (cs *CallbackSerializer) run(ctx context.Context) { func (cs *CallbackSerializer) run(ctx context.Context) {

View File

@ -77,7 +77,7 @@ func (ps *PubSub) Subscribe(sub Subscriber) (cancel func()) {
if ps.msg != nil { if ps.msg != nil {
msg := ps.msg msg := ps.msg
ps.cs.Schedule(func(context.Context) { ps.cs.TrySchedule(func(context.Context) {
ps.mu.Lock() ps.mu.Lock()
defer ps.mu.Unlock() defer ps.mu.Unlock()
if !ps.subscribers[sub] { if !ps.subscribers[sub] {
@ -103,7 +103,7 @@ func (ps *PubSub) Publish(msg any) {
ps.msg = msg ps.msg = msg
for sub := range ps.subscribers { for sub := range ps.subscribers {
s := sub s := sub
ps.cs.Schedule(func(context.Context) { ps.cs.TrySchedule(func(context.Context) {
ps.mu.Lock() ps.mu.Lock()
defer ps.mu.Unlock() defer ps.mu.Unlock()
if !ps.subscribers[s] { if !ps.subscribers[s] {

View File

@ -20,8 +20,6 @@ package grpcutil
import ( import (
"strings" "strings"
"google.golang.org/grpc/internal/envconfig"
) )
// RegisteredCompressorNames holds names of the registered compressors. // RegisteredCompressorNames holds names of the registered compressors.
@ -40,8 +38,5 @@ func IsCompressorNameRegistered(name string) bool {
// RegisteredCompressors returns a string of registered compressor names // RegisteredCompressors returns a string of registered compressor names
// separated by comma. // separated by comma.
func RegisteredCompressors() string { func RegisteredCompressors() string {
if !envconfig.AdvertiseCompressors {
return ""
}
return strings.Join(RegisteredCompressorNames, ",") return strings.Join(RegisteredCompressorNames, ",")
} }

View File

@ -106,6 +106,14 @@ var (
// This is used in the 1.0 release of gcp/observability, and thus must not be // This is used in the 1.0 release of gcp/observability, and thus must not be
// deleted or changed. // deleted or changed.
ClearGlobalDialOptions func() ClearGlobalDialOptions func()
// AddGlobalPerTargetDialOptions adds a PerTargetDialOption that will be
// configured for newly created ClientConns.
AddGlobalPerTargetDialOptions any // func (opt any)
// ClearGlobalPerTargetDialOptions clears the slice of global late apply
// dial options.
ClearGlobalPerTargetDialOptions func()
// JoinDialOptions combines the dial options passed as arguments into a // JoinDialOptions combines the dial options passed as arguments into a
// single dial option. // single dial option.
JoinDialOptions any // func(...grpc.DialOption) grpc.DialOption JoinDialOptions any // func(...grpc.DialOption) grpc.DialOption
@ -126,7 +134,8 @@ var (
// deleted or changed. // deleted or changed.
BinaryLogger any // func(binarylog.Logger) grpc.ServerOption BinaryLogger any // func(binarylog.Logger) grpc.ServerOption
// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a provided grpc.ClientConn // SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a
// provided grpc.ClientConn.
SubscribeToConnectivityStateChanges any // func(*grpc.ClientConn, grpcsync.Subscriber) SubscribeToConnectivityStateChanges any // func(*grpc.ClientConn, grpcsync.Subscriber)
// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using // NewXDSResolverWithConfigForTesting creates a new xds resolver builder using
@ -184,21 +193,45 @@ var (
ChannelzTurnOffForTesting func() ChannelzTurnOffForTesting func()
// TriggerXDSResourceNameNotFoundForTesting triggers the resource-not-found // TriggerXDSResourceNotFoundForTesting causes the provided xDS Client to
// error for a given resource type and name. This is usually triggered when // invoke resource-not-found error for the given resource type and name.
// the associated watch timer fires. For testing purposes, having this TriggerXDSResourceNotFoundForTesting any // func(xdsclient.XDSClient, xdsresource.Type, string) error
// function makes events more predictable than relying on timer events.
TriggerXDSResourceNameNotFoundForTesting any // func(func(xdsresource.Type, string), string, string) error
// TriggerXDSResourceNotFoundClient invokes the testing xDS Client singleton // FromOutgoingContextRaw returns the un-merged, intermediary contents of
// to invoke resource not found for a resource type name and resource name. // metadata.rawMD.
TriggerXDSResourceNameNotFoundClient any // func(string, string) error
// FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD.
FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool) FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool)
// UserSetDefaultScheme is set to true if the user has overridden the
// default resolver scheme.
UserSetDefaultScheme bool = false
// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n
// is the number of elements. swap swaps the elements with indexes i and j.
ShuffleAddressListForTesting any // func(n int, swap func(i, j int))
// ConnectedAddress returns the connected address for a SubConnState. The
// address is only valid if the state is READY.
ConnectedAddress any // func (scs SubConnState) resolver.Address
// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address)
// SnapshotMetricRegistryForTesting snapshots the global data of the metric
// registry. Returns a cleanup function that sets the metric registry to its
// original state. Only called in testing functions.
SnapshotMetricRegistryForTesting func() func()
// SetDefaultBufferPoolForTesting updates the default buffer pool, for
// testing purposes.
SetDefaultBufferPoolForTesting any // func(mem.BufferPool)
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
// testing purposes.
SetBufferPoolingThresholdForTesting any // func(int)
) )
// HealthChecker defines the signature of the client-side LB channel health checking function. // HealthChecker defines the signature of the client-side LB channel health
// checking function.
// //
// The implementation is expected to create a health checking RPC stream by // The implementation is expected to create a health checking RPC stream by
// calling newStream(), watch for the health status of serviceName, and report // calling newStream(), watch for the health status of serviceName, and report

View File

@ -24,9 +24,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
protov1 "github.com/golang/protobuf/proto"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
protov2 "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt"
) )
const jsonIndent = " " const jsonIndent = " "
@ -35,21 +34,14 @@ const jsonIndent = " "
// //
// If marshal fails, it falls back to fmt.Sprintf("%+v"). // If marshal fails, it falls back to fmt.Sprintf("%+v").
func ToJSON(e any) string { func ToJSON(e any) string {
switch ee := e.(type) { if ee, ok := e.(protoadapt.MessageV1); ok {
case protov1.Message: e = protoadapt.MessageV2Of(ee)
mm := protojson.MarshalOptions{Indent: jsonIndent}
ret, err := mm.Marshal(protov1.MessageV2(ee))
if err != nil {
// This may fail for proto.Anys, e.g. for xDS v2, LDS, the v2
// messages are not imported, and this will fail because the message
// is not found.
return fmt.Sprintf("%+v", ee)
} }
return string(ret)
case protov2.Message: if ee, ok := e.(protoadapt.MessageV2); ok {
mm := protojson.MarshalOptions{ mm := protojson.MarshalOptions{
Multiline: true,
Indent: jsonIndent, Indent: jsonIndent,
Multiline: true,
} }
ret, err := mm.Marshal(ee) ret, err := mm.Marshal(ee)
if err != nil { if err != nil {
@ -59,13 +51,13 @@ func ToJSON(e any) string {
return fmt.Sprintf("%+v", ee) return fmt.Sprintf("%+v", ee)
} }
return string(ret) return string(ret)
default: }
ret, err := json.MarshalIndent(ee, "", jsonIndent)
ret, err := json.MarshalIndent(e, "", jsonIndent)
if err != nil { if err != nil {
return fmt.Sprintf("%+v", ee) return fmt.Sprintf("%+v", e)
} }
return string(ret) return string(ret)
}
} }
// FormatJSON formats the input json bytes with indentation. // FormatJSON formats the input json bytes with indentation.

View File

@ -24,6 +24,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math/rand"
"net" "net"
"os" "os"
"strconv" "strconv"
@ -35,21 +36,35 @@ import (
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/resolver/dns/internal" "google.golang.org/grpc/internal/resolver/dns/internal"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB var (
// addresses from SRV records. Must not be changed after init time. // EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
var EnableSRVLookups = false // addresses from SRV records. Must not be changed after init time.
EnableSRVLookups = false
var logger = grpclog.Component("dns") // MinResolutionInterval is the minimum interval at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionInterval = 30 * time.Second
// ResolvingTimeout specifies the maximum duration for a DNS resolution request.
// If the timeout expires before a response is received, the request will be canceled.
//
// It is recommended to set this value at application startup. Avoid modifying this variable
// after initialization as it's not thread-safe for concurrent modification.
ResolvingTimeout = 30 * time.Second
logger = grpclog.Component("dns")
)
func init() { func init() {
resolver.Register(NewBuilder()) resolver.Register(NewBuilder())
internal.TimeAfterFunc = time.After internal.TimeAfterFunc = time.After
internal.TimeNowFunc = time.Now
internal.TimeUntilFunc = time.Until
internal.NewNetResolver = newNetResolver internal.NewNetResolver = newNetResolver
internal.AddressDialer = addressDialer internal.AddressDialer = addressDialer
} }
@ -196,12 +211,12 @@ func (d *dnsResolver) watcher() {
err = d.cc.UpdateState(*state) err = d.cc.UpdateState(*state)
} }
var waitTime time.Duration var nextResolutionTime time.Time
if err == nil { if err == nil {
// Success resolving, wait for the next ResolveNow. However, also wait 30 // Success resolving, wait for the next ResolveNow. However, also wait 30
// seconds at the very least to prevent constantly re-resolving. // seconds at the very least to prevent constantly re-resolving.
backoffIndex = 1 backoffIndex = 1
waitTime = internal.MinResolutionRate nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
select { select {
case <-d.ctx.Done(): case <-d.ctx.Done():
return return
@ -210,29 +225,29 @@ func (d *dnsResolver) watcher() {
} else { } else {
// Poll on an error found in DNS Resolver or an error received from // Poll on an error found in DNS Resolver or an error received from
// ClientConn. // ClientConn.
waitTime = backoff.DefaultExponential.Backoff(backoffIndex) nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
backoffIndex++ backoffIndex++
} }
select { select {
case <-d.ctx.Done(): case <-d.ctx.Done():
return return
case <-internal.TimeAfterFunc(waitTime): case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
} }
} }
} }
func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) { func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
if !EnableSRVLookups { if !EnableSRVLookups {
return nil, nil return nil, nil
} }
var newAddrs []resolver.Address var newAddrs []resolver.Address
_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host) _, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
if err != nil { if err != nil {
err = handleDNSError(err, "SRV") // may become nil err = handleDNSError(err, "SRV") // may become nil
return nil, err return nil, err
} }
for _, s := range srvs { for _, s := range srvs {
lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target) lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
if err != nil { if err != nil {
err = handleDNSError(err, "A") // may become nil err = handleDNSError(err, "A") // may become nil
if err == nil { if err == nil {
@ -269,8 +284,8 @@ func handleDNSError(err error, lookupType string) error {
return err return err
} }
func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult { func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host) ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
if err != nil { if err != nil {
if envconfig.TXTErrIgnore { if envconfig.TXTErrIgnore {
return nil return nil
@ -297,8 +312,8 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
return d.cc.ParseServiceConfig(sc) return d.cc.ParseServiceConfig(sc)
} }
func (d *dnsResolver) lookupHost() ([]resolver.Address, error) { func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(d.ctx, d.host) addrs, err := d.resolver.LookupHost(ctx, d.host)
if err != nil { if err != nil {
err = handleDNSError(err, "A") err = handleDNSError(err, "A")
return nil, err return nil, err
@ -316,8 +331,10 @@ func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
} }
func (d *dnsResolver) lookup() (*resolver.State, error) { func (d *dnsResolver) lookup() (*resolver.State, error) {
srv, srvErr := d.lookupSRV() ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
addrs, hostErr := d.lookupHost() defer cancel()
srv, srvErr := d.lookupSRV(ctx)
addrs, hostErr := d.lookupHost(ctx)
if hostErr != nil && (srvErr != nil || len(srv) == 0) { if hostErr != nil && (srvErr != nil || len(srv) == 0) {
return nil, hostErr return nil, hostErr
} }
@ -327,7 +344,7 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv}) state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
} }
if !d.disableServiceConfig { if !d.disableServiceConfig {
state.ServiceConfig = d.lookupTXT() state.ServiceConfig = d.lookupTXT(ctx)
} }
return &state, nil return &state, nil
} }
@ -408,7 +425,7 @@ func chosenByPercentage(a *int) bool {
if a == nil { if a == nil {
return true return true
} }
return grpcrand.Intn(100)+1 <= *a return rand.Intn(100)+1 <= *a
} }
func canaryingSC(js string) string { func canaryingSC(js string) string {

View File

@ -28,7 +28,7 @@ import (
// NetResolver groups the methods on net.Resolver that are used by the DNS // NetResolver groups the methods on net.Resolver that are used by the DNS
// resolver implementation. This allows the default net.Resolver instance to be // resolver implementation. This allows the default net.Resolver instance to be
// overidden from tests. // overridden from tests.
type NetResolver interface { type NetResolver interface {
LookupHost(ctx context.Context, host string) (addrs []string, err error) LookupHost(ctx context.Context, host string) (addrs []string, err error)
LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
@ -50,16 +50,23 @@ var (
// The following vars are overridden from tests. // The following vars are overridden from tests.
var ( var (
// MinResolutionRate is the minimum rate at which re-resolutions are
// allowed. This helps to prevent excessive re-resolution.
MinResolutionRate = 30 * time.Second
// TimeAfterFunc is used by the DNS resolver to wait for the given duration // TimeAfterFunc is used by the DNS resolver to wait for the given duration
// to elapse. In non-test code, this is implemented by time.After. In test // to elapse. In non-test code, this is implemented by time.After. In test
// code, this can be used to control the amount of time the resolver is // code, this can be used to control the amount of time the resolver is
// blocked waiting for the duration to elapse. // blocked waiting for the duration to elapse.
TimeAfterFunc func(time.Duration) <-chan time.Time TimeAfterFunc func(time.Duration) <-chan time.Time
// TimeNowFunc is used by the DNS resolver to get the current time.
// In non-test code, this is implemented by time.Now. In test code,
// this can be used to control the current time for the resolver.
TimeNowFunc func() time.Time
// TimeUntilFunc is used by the DNS resolver to calculate the remaining
// wait time for re-resolution. In non-test code, this is implemented by
// time.Until. In test code, this can be used to control the remaining
// time for resolver to wait for re-resolution.
TimeUntilFunc func(time.Time) time.Duration
// NewNetResolver returns the net.Resolver instance for the given target. // NewNetResolver returns the net.Resolver instance for the given target.
NewNetResolver func(string) (NetResolver, error) NewNetResolver func(string) (NetResolver, error)

42
vendor/google.golang.org/grpc/internal/stats/labels.go generated vendored Normal file
View File

@ -0,0 +1,42 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package stats provides internal stats related functionality.
package stats
import "context"
// Labels are the labels for metrics.
type Labels struct {
// TelemetryLabels are the telemetry labels to record.
TelemetryLabels map[string]string
}
type labelsKey struct{}
// GetLabels returns the Labels stored in the context, or nil if there is one.
func GetLabels(ctx context.Context) *Labels {
labels, _ := ctx.Value(labelsKey{}).(*Labels)
return labels
}
// SetLabels sets the Labels in the context.
func SetLabels(ctx context.Context, labels *Labels) context.Context {
// could also append
return context.WithValue(ctx, labelsKey{}, labels)
}

View File

@ -0,0 +1,95 @@
/*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package stats
import (
"fmt"
estats "google.golang.org/grpc/experimental/stats"
"google.golang.org/grpc/stats"
)
// MetricsRecorderList forwards Record calls to all of its metricsRecorders.
//
// It eats any record calls where the label values provided do not match the
// number of label keys.
type MetricsRecorderList struct {
// metricsRecorders are the metrics recorders this list will forward to.
metricsRecorders []estats.MetricsRecorder
}
// NewMetricsRecorderList creates a new metric recorder list with all the stats
// handlers provided which implement the MetricsRecorder interface.
// If no stats handlers provided implement the MetricsRecorder interface,
// the MetricsRecorder list returned is a no-op.
func NewMetricsRecorderList(shs []stats.Handler) *MetricsRecorderList {
var mrs []estats.MetricsRecorder
for _, sh := range shs {
if mr, ok := sh.(estats.MetricsRecorder); ok {
mrs = append(mrs, mr)
}
}
return &MetricsRecorderList{
metricsRecorders: mrs,
}
}
func verifyLabels(desc *estats.MetricDescriptor, labelsRecv ...string) {
if got, want := len(labelsRecv), len(desc.Labels)+len(desc.OptionalLabels); got != want {
panic(fmt.Sprintf("Received %d labels in call to record metric %q, but expected %d.", got, desc.Name, want))
}
}
func (l *MetricsRecorderList) RecordInt64Count(handle *estats.Int64CountHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Count(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordFloat64Count(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordInt64Histo(handle *estats.Int64HistoHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Histo(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordFloat64Histo(handle *estats.Float64HistoHandle, incr float64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordFloat64Histo(handle, incr, labels...)
}
}
func (l *MetricsRecorderList) RecordInt64Gauge(handle *estats.Int64GaugeHandle, incr int64, labels ...string) {
verifyLabels(handle.Descriptor(), labels...)
for _, metricRecorder := range l.metricsRecorders {
metricRecorder.RecordInt64Gauge(handle, incr, labels...)
}
}

View File

@ -44,7 +44,7 @@ func NetDialerWithTCPKeepalive() *net.Dialer {
// combination of unconditionally enabling TCP keepalives here, and // combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the // disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for // KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters. // the TCP keepalive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error { Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_KEEPALIVE, 1) unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_KEEPALIVE, 1)

View File

@ -44,7 +44,7 @@ func NetDialerWithTCPKeepalive() *net.Dialer {
// combination of unconditionally enabling TCP keepalives here, and // combination of unconditionally enabling TCP keepalives here, and
// disabling the overriding of TCP keepalive parameters by setting the // disabling the overriding of TCP keepalive parameters by setting the
// KeepAlive field to a negative value above, results in OS defaults for // KeepAlive field to a negative value above, results in OS defaults for
// the TCP keealive interval and time parameters. // the TCP keepalive interval and time parameters.
Control: func(_, _ string, c syscall.RawConn) error { Control: func(_, _ string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {
windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_KEEPALIVE, 1) windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_KEEPALIVE, 1)

View File

@ -32,6 +32,7 @@ import (
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -148,9 +149,9 @@ type dataFrame struct {
streamID uint32 streamID uint32
endStream bool endStream bool
h []byte h []byte
d []byte reader mem.Reader
// onEachWrite is called every time // onEachWrite is called every time
// a part of d is written out. // a part of data is written out.
onEachWrite func() onEachWrite func()
} }
@ -193,7 +194,7 @@ type goAway struct {
code http2.ErrCode code http2.ErrCode
debugData []byte debugData []byte
headsUp bool headsUp bool
closeConn error // if set, loopyWriter will exit, resulting in conn closure closeConn error // if set, loopyWriter will exit with this error
} }
func (*goAway) isTransportResponseFrame() bool { return false } func (*goAway) isTransportResponseFrame() bool { return false }
@ -289,18 +290,22 @@ func (l *outStreamList) dequeue() *outStream {
} }
// controlBuffer is a way to pass information to loopy. // controlBuffer is a way to pass information to loopy.
// Information is passed as specific struct types called control frames. //
// A control frame not only represents data, messages or headers to be sent out // Information is passed as specific struct types called control frames. A
// but can also be used to instruct loopy to update its internal state. // control frame not only represents data, messages or headers to be sent out
// It shouldn't be confused with an HTTP2 frame, although some of the control frames // but can also be used to instruct loopy to update its internal state. It
// like dataFrame and headerFrame do go out on wire as HTTP2 frames. // shouldn't be confused with an HTTP2 frame, although some of the control
// frames like dataFrame and headerFrame do go out on wire as HTTP2 frames.
type controlBuffer struct { type controlBuffer struct {
ch chan struct{} wakeupCh chan struct{} // Unblocks readers waiting for something to read.
done <-chan struct{} done <-chan struct{} // Closed when the transport is done.
// Mutex guards all the fields below, except trfChan which can be read
// atomically without holding mu.
mu sync.Mutex mu sync.Mutex
consumerWaiting bool consumerWaiting bool // True when readers are blocked waiting for new data.
list *itemList closed bool // True when the controlbuf is finished.
err error list *itemList // List of queued control frames.
// transportResponseFrames counts the number of queued items that represent // transportResponseFrames counts the number of queued items that represent
// the response of an action initiated by the peer. trfChan is created // the response of an action initiated by the peer. trfChan is created
@ -308,47 +313,59 @@ type controlBuffer struct {
// closed and nilled when transportResponseFrames drops below the // closed and nilled when transportResponseFrames drops below the
// threshold. Both fields are protected by mu. // threshold. Both fields are protected by mu.
transportResponseFrames int transportResponseFrames int
trfChan atomic.Value // chan struct{} trfChan atomic.Pointer[chan struct{}]
} }
func newControlBuffer(done <-chan struct{}) *controlBuffer { func newControlBuffer(done <-chan struct{}) *controlBuffer {
return &controlBuffer{ return &controlBuffer{
ch: make(chan struct{}, 1), wakeupCh: make(chan struct{}, 1),
list: &itemList{}, list: &itemList{},
done: done, done: done,
} }
} }
// throttle blocks if there are too many incomingSettings/cleanupStreams in the // throttle blocks if there are too many frames in the control buf that
// controlbuf. // represent the response of an action initiated by the peer, like
// incomingSettings cleanupStreams etc.
func (c *controlBuffer) throttle() { func (c *controlBuffer) throttle() {
ch, _ := c.trfChan.Load().(chan struct{}) if ch := c.trfChan.Load(); ch != nil {
if ch != nil {
select { select {
case <-ch: case <-(*ch):
case <-c.done: case <-c.done:
} }
} }
} }
// put adds an item to the controlbuf.
func (c *controlBuffer) put(it cbItem) error { func (c *controlBuffer) put(it cbItem) error {
_, err := c.executeAndPut(nil, it) _, err := c.executeAndPut(nil, it)
return err return err
} }
func (c *controlBuffer) executeAndPut(f func(it any) bool, it cbItem) (bool, error) { // executeAndPut runs f, and if the return value is true, adds the given item to
var wakeUp bool // the controlbuf. The item could be nil, in which case, this method simply
// executes f and does not add the item to the controlbuf.
//
// The first return value indicates whether the item was successfully added to
// the control buffer. A non-nil error, specifically ErrConnClosing, is returned
// if the control buffer is already closed.
func (c *controlBuffer) executeAndPut(f func() bool, it cbItem) (bool, error) {
c.mu.Lock() c.mu.Lock()
if c.err != nil { defer c.mu.Unlock()
c.mu.Unlock()
return false, c.err if c.closed {
return false, ErrConnClosing
} }
if f != nil { if f != nil {
if !f(it) { // f wasn't successful if !f() { // f wasn't successful
c.mu.Unlock()
return false, nil return false, nil
} }
} }
if it == nil {
return true, nil
}
var wakeUp bool
if c.consumerWaiting { if c.consumerWaiting {
wakeUp = true wakeUp = true
c.consumerWaiting = false c.consumerWaiting = false
@ -359,98 +376,102 @@ func (c *controlBuffer) executeAndPut(f func(it any) bool, it cbItem) (bool, err
if c.transportResponseFrames == maxQueuedTransportResponseFrames { if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are adding the frame that puts us over the threshold; create // We are adding the frame that puts us over the threshold; create
// a throttling channel. // a throttling channel.
c.trfChan.Store(make(chan struct{})) ch := make(chan struct{})
c.trfChan.Store(&ch)
} }
} }
c.mu.Unlock()
if wakeUp { if wakeUp {
select { select {
case c.ch <- struct{}{}: case c.wakeupCh <- struct{}{}:
default: default:
} }
} }
return true, nil return true, nil
} }
// Note argument f should never be nil. // get returns the next control frame from the control buffer. If block is true
func (c *controlBuffer) execute(f func(it any) bool, it any) (bool, error) { // **and** there are no control frames in the control buffer, the call blocks
c.mu.Lock() // until one of the conditions is met: there is a frame to return or the
if c.err != nil { // transport is closed.
c.mu.Unlock()
return false, c.err
}
if !f(it) { // f wasn't successful
c.mu.Unlock()
return false, nil
}
c.mu.Unlock()
return true, nil
}
func (c *controlBuffer) get(block bool) (any, error) { func (c *controlBuffer) get(block bool) (any, error) {
for { for {
c.mu.Lock() c.mu.Lock()
if c.err != nil { frame, err := c.getOnceLocked()
if frame != nil || err != nil || !block {
// If we read a frame or an error, we can return to the caller. The
// call to getOnceLocked() returns a nil frame and a nil error if
// there is nothing to read, and in that case, if the caller asked
// us not to block, we can return now as well.
c.mu.Unlock() c.mu.Unlock()
return nil, c.err return frame, err
}
if !c.list.isEmpty() {
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Load().(chan struct{})
close(ch)
c.trfChan.Store((chan struct{})(nil))
}
c.transportResponseFrames--
}
c.mu.Unlock()
return h, nil
}
if !block {
c.mu.Unlock()
return nil, nil
} }
c.consumerWaiting = true c.consumerWaiting = true
c.mu.Unlock() c.mu.Unlock()
// Release the lock above and wait to be woken up.
select { select {
case <-c.ch: case <-c.wakeupCh:
case <-c.done: case <-c.done:
return nil, errors.New("transport closed by client") return nil, errors.New("transport closed by client")
} }
} }
} }
// Callers must not use this method, but should instead use get().
//
// Caller must hold c.mu.
func (c *controlBuffer) getOnceLocked() (any, error) {
if c.closed {
return false, ErrConnClosing
}
if c.list.isEmpty() {
return nil, nil
}
h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Swap(nil)
close(*ch)
}
c.transportResponseFrames--
}
return h, nil
}
// finish closes the control buffer, cleaning up any streams that have queued
// header frames. Once this method returns, no more frames can be added to the
// control buffer, and attempts to do so will return ErrConnClosing.
func (c *controlBuffer) finish() { func (c *controlBuffer) finish() {
c.mu.Lock() c.mu.Lock()
if c.err != nil { defer c.mu.Unlock()
c.mu.Unlock()
if c.closed {
return return
} }
c.err = ErrConnClosing c.closed = true
// There may be headers for streams in the control buffer. // There may be headers for streams in the control buffer.
// These streams need to be cleaned out since the transport // These streams need to be cleaned out since the transport
// is still not aware of these yet. // is still not aware of these yet.
for head := c.list.dequeueAll(); head != nil; head = head.next { for head := c.list.dequeueAll(); head != nil; head = head.next {
hdr, ok := head.it.(*headerFrame) switch v := head.it.(type) {
if !ok { case *headerFrame:
continue if v.onOrphaned != nil { // It will be nil on the server-side.
v.onOrphaned(ErrConnClosing)
} }
if hdr.onOrphaned != nil { // It will be nil on the server-side. case *dataFrame:
hdr.onOrphaned(ErrConnClosing) _ = v.reader.Close()
} }
} }
// In case throttle() is currently in flight, it needs to be unblocked. // In case throttle() is currently in flight, it needs to be unblocked.
// Otherwise, the transport may not close, since the transport is closed by // Otherwise, the transport may not close, since the transport is closed by
// the reader encountering the connection error. // the reader encountering the connection error.
ch, _ := c.trfChan.Load().(chan struct{}) ch := c.trfChan.Swap(nil)
if ch != nil { if ch != nil {
close(ch) close(*ch)
} }
c.trfChan.Store((chan struct{})(nil))
c.mu.Unlock()
} }
type side int type side int
@ -466,7 +487,7 @@ const (
// stream maintains a queue of data frames; as loopy receives data frames // stream maintains a queue of data frames; as loopy receives data frames
// it gets added to the queue of the relevant stream. // it gets added to the queue of the relevant stream.
// Loopy goes over this list of active streams by processing one node every iteration, // Loopy goes over this list of active streams by processing one node every iteration,
// thereby closely resemebling to a round-robin scheduling over all streams. While // thereby closely resembling a round-robin scheduling over all streams. While
// processing a stream, loopy writes out data bytes from this stream capped by the min // processing a stream, loopy writes out data bytes from this stream capped by the min
// of http2MaxFrameLen, connection-level flow control and stream-level flow control. // of http2MaxFrameLen, connection-level flow control and stream-level flow control.
type loopyWriter struct { type loopyWriter struct {
@ -490,12 +511,13 @@ type loopyWriter struct {
draining bool draining bool
conn net.Conn conn net.Conn
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
// Side-specific handlers // Side-specific handlers
ssGoAwayHandler func(*goAway) (bool, error) ssGoAwayHandler func(*goAway) (bool, error)
} }
func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger) *loopyWriter { func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter {
var buf bytes.Buffer var buf bytes.Buffer
l := &loopyWriter{ l := &loopyWriter{
side: s, side: s,
@ -510,6 +532,8 @@ func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimato
bdpEst: bdpEst, bdpEst: bdpEst,
conn: conn, conn: conn,
logger: logger, logger: logger,
ssGoAwayHandler: goAwayHandler,
bufferPool: bufferPool,
} }
return l return l
} }
@ -767,6 +791,11 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
// not be established yet. // not be established yet.
delete(l.estdStreams, c.streamID) delete(l.estdStreams, c.streamID)
str.deleteSelf() str.deleteSelf()
for head := str.itl.dequeueAll(); head != nil; head = head.next {
if df, ok := head.it.(*dataFrame); ok {
_ = df.reader.Close()
}
}
} }
if c.rst { // If RST_STREAM needs to be sent. if c.rst { // If RST_STREAM needs to be sent.
if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil { if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil {
@ -902,16 +931,18 @@ func (l *loopyWriter) processData() (bool, error) {
dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream. dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream.
// A data item is represented by a dataFrame, since it later translates into // A data item is represented by a dataFrame, since it later translates into
// multiple HTTP2 data frames. // multiple HTTP2 data frames.
// Every dataFrame has two buffers; h that keeps grpc-message header and d that is actual data. // Every dataFrame has two buffers; h that keeps grpc-message header and data
// As an optimization to keep wire traffic low, data from d is copied to h to make as big as the // that is the actual message. As an optimization to keep wire traffic low, data
// maximum possible HTTP2 frame size. // from data is copied to h to make as big as the maximum possible HTTP2 frame
// size.
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // Empty data frame if len(dataItem.h) == 0 && dataItem.reader.Remaining() == 0 { // Empty data frame
// Client sends out empty data frame with endStream = true // Client sends out empty data frame with endStream = true
if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil {
return false, err return false, err
} }
str.itl.dequeue() // remove the empty data item from stream str.itl.dequeue() // remove the empty data item from stream
_ = dataItem.reader.Close()
if str.itl.isEmpty() { if str.itl.isEmpty() {
str.state = empty str.state = empty
} else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers.
@ -926,9 +957,7 @@ func (l *loopyWriter) processData() (bool, error) {
} }
return false, nil return false, nil
} }
var (
buf []byte
)
// Figure out the maximum size we can send // Figure out the maximum size we can send
maxSize := http2MaxFrameLen maxSize := http2MaxFrameLen
if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control. if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control.
@ -942,43 +971,50 @@ func (l *loopyWriter) processData() (bool, error) {
} }
// Compute how much of the header and data we can send within quota and max frame length // Compute how much of the header and data we can send within quota and max frame length
hSize := min(maxSize, len(dataItem.h)) hSize := min(maxSize, len(dataItem.h))
dSize := min(maxSize-hSize, len(dataItem.d)) dSize := min(maxSize-hSize, dataItem.reader.Remaining())
if hSize != 0 { remainingBytes := len(dataItem.h) + dataItem.reader.Remaining() - hSize - dSize
if dSize == 0 {
buf = dataItem.h
} else {
// We can add some data to grpc message header to distribute bytes more equally across frames.
// Copy on the stack to avoid generating garbage
var localBuf [http2MaxFrameLen]byte
copy(localBuf[:hSize], dataItem.h)
copy(localBuf[hSize:], dataItem.d[:dSize])
buf = localBuf[:hSize+dSize]
}
} else {
buf = dataItem.d
}
size := hSize + dSize size := hSize + dSize
var buf *[]byte
if hSize != 0 && dSize == 0 {
buf = &dataItem.h
} else {
// Note: this is only necessary because the http2.Framer does not support
// partially writing a frame, so the sequence must be materialized into a buffer.
// TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed.
pool := l.bufferPool
if pool == nil {
// Note that this is only supposed to be nil in tests. Otherwise, stream is
// always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
buf = pool.Get(size)
defer pool.Put(buf)
copy((*buf)[:hSize], dataItem.h)
_, _ = dataItem.reader.Read((*buf)[hSize:])
}
// Now that outgoing flow controls are checked we can replenish str's write quota // Now that outgoing flow controls are checked we can replenish str's write quota
str.wq.replenish(size) str.wq.replenish(size)
var endStream bool var endStream bool
// If this is the last data message on this stream and all of it can be written in this iteration. // If this is the last data message on this stream and all of it can be written in this iteration.
if dataItem.endStream && len(dataItem.h)+len(dataItem.d) <= size { if dataItem.endStream && remainingBytes == 0 {
endStream = true endStream = true
} }
if dataItem.onEachWrite != nil { if dataItem.onEachWrite != nil {
dataItem.onEachWrite() dataItem.onEachWrite()
} }
if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil { if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil {
return false, err return false, err
} }
str.bytesOutStanding += size str.bytesOutStanding += size
l.sendQuota -= uint32(size) l.sendQuota -= uint32(size)
dataItem.h = dataItem.h[hSize:] dataItem.h = dataItem.h[hSize:]
dataItem.d = dataItem.d[dSize:]
if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out. if remainingBytes == 0 { // All the data from that message was written out.
_ = dataItem.reader.Close()
str.itl.dequeue() str.itl.dequeue()
} }
if str.itl.isEmpty() { if str.itl.isEmpty() {

View File

@ -24,7 +24,6 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -40,6 +39,7 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -50,15 +50,11 @@ import (
// NewServerHandlerTransport returns a ServerTransport handling gRPC from // NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error. // inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2. // It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) { func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
if r.ProtoMajor != 2 { if r.Method != http.MethodPost {
msg := "gRPC requires HTTP/2" w.Header().Set("Allow", http.MethodPost)
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
if r.Method != "POST" {
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method) msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
http.Error(w, msg, http.StatusBadRequest) http.Error(w, msg, http.StatusMethodNotAllowed)
return nil, errors.New(msg) return nil, errors.New(msg)
} }
contentType := r.Header.Get("Content-Type") contentType := r.Header.Get("Content-Type")
@ -69,6 +65,11 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
http.Error(w, msg, http.StatusUnsupportedMediaType) http.Error(w, msg, http.StatusUnsupportedMediaType)
return nil, errors.New(msg) return nil, errors.New(msg)
} }
if r.ProtoMajor != 2 {
msg := "gRPC requires HTTP/2"
http.Error(w, msg, http.StatusHTTPVersionNotSupported)
return nil, errors.New(msg)
}
if _, ok := w.(http.Flusher); !ok { if _, ok := w.(http.Flusher); !ok {
msg := "gRPC requires a ResponseWriter supporting http.Flusher" msg := "gRPC requires a ResponseWriter supporting http.Flusher"
http.Error(w, msg, http.StatusInternalServerError) http.Error(w, msg, http.StatusInternalServerError)
@ -97,6 +98,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
contentType: contentType, contentType: contentType,
contentSubtype: contentSubtype, contentSubtype: contentSubtype,
stats: stats, stats: stats,
bufferPool: bufferPool,
} }
st.logger = prefixLoggerForServerHandlerTransport(st) st.logger = prefixLoggerForServerHandlerTransport(st)
@ -170,6 +172,8 @@ type serverHandlerTransport struct {
stats []stats.Handler stats []stats.Handler
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
bufferPool mem.BufferPool
} }
func (ht *serverHandlerTransport) Close(err error) { func (ht *serverHandlerTransport) Close(err error) {
@ -243,6 +247,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
} }
s.hdrMu.Lock() s.hdrMu.Lock()
defer s.hdrMu.Unlock()
if p := st.Proto(); p != nil && len(p.Details) > 0 { if p := st.Proto(); p != nil && len(p.Details) > 0 {
delete(s.trailer, grpcStatusDetailsBinHeader) delete(s.trailer, grpcStatusDetailsBinHeader)
stBytes, err := proto.Marshal(p) stBytes, err := proto.Marshal(p)
@ -267,7 +272,6 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
} }
} }
} }
s.hdrMu.Unlock()
}) })
if err == nil { // transport has not been closed if err == nil { // transport has not been closed
@ -329,16 +333,28 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
s.hdrMu.Unlock() s.hdrMu.Unlock()
} }
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
// Always take a reference because otherwise there is no guarantee the data will
// be available after this function returns. This is what callers to Write
// expect.
data.Ref()
headersWritten := s.updateHeaderSent() headersWritten := s.updateHeaderSent()
return ht.do(func() { err := ht.do(func() {
defer data.Free()
if !headersWritten { if !headersWritten {
ht.writePendingHeaders(s) ht.writePendingHeaders(s)
} }
ht.rw.Write(hdr) ht.rw.Write(hdr)
ht.rw.Write(data) for _, b := range data {
_, _ = ht.rw.Write(b.ReadOnlyData())
}
ht.rw.(http.Flusher).Flush() ht.rw.(http.Flusher).Flush()
}) })
if err != nil {
data.Free()
return err
}
return nil
} }
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
@ -405,7 +421,7 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
headerWireLength: 0, // won't have access to header wire length until golang/go#18997. headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {}, windowHandler: func(int) {},
} }
@ -414,21 +430,19 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
go func() { go func() {
defer close(readerDone) defer close(readerDone)
// TODO: minimize garbage, optimize recvBuffer code/ownership for {
const readSize = 8196 buf := ht.bufferPool.Get(http2MaxFrameLen)
for buf := make([]byte, readSize); ; { n, err := req.Body.Read(*buf)
n, err := req.Body.Read(buf)
if n > 0 { if n > 0 {
s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])}) *buf = (*buf)[:n]
buf = buf[n:] s.buf.put(recvMsg{buffer: mem.NewBuffer(buf, ht.bufferPool)})
} else {
ht.bufferPool.Put(buf)
} }
if err != nil { if err != nil {
s.buf.put(recvMsg{err: mapRecvMsgError(err)}) s.buf.put(recvMsg{err: mapRecvMsgError(err)})
return return
} }
if len(buf) == 0 {
buf = make([]byte, readSize)
}
} }
}() }()

View File

@ -47,6 +47,7 @@ import (
isyscall "google.golang.org/grpc/internal/syscall" isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
@ -59,6 +60,8 @@ import (
// atomically. // atomically.
var clientConnectionCounter uint64 var clientConnectionCounter uint64
var goAwayLoopyWriterTimeout = 5 * time.Second
var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool)) var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))
// http2Client implements the ClientTransport interface with HTTP2. // http2Client implements the ClientTransport interface with HTTP2.
@ -114,11 +117,11 @@ type http2Client struct {
streamQuota int64 streamQuota int64
streamsQuotaAvailable chan struct{} streamsQuotaAvailable chan struct{}
waitingStreams uint32 waitingStreams uint32
nextID uint32
registeredCompressors string registeredCompressors string
// Do not access controlBuf with mu held. // Do not access controlBuf with mu held.
mu sync.Mutex // guard the following variables mu sync.Mutex // guard the following variables
nextID uint32
state transportState state transportState
activeStreams map[uint32]*Stream activeStreams map[uint32]*Stream
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
@ -140,13 +143,11 @@ type http2Client struct {
// variable. // variable.
kpDormant bool kpDormant bool
// Fields below are for channelz metric collection. channelz *channelz.Socket
channelzID *channelz.Identifier
czData *channelzData
onClose func(GoAwayReason) onClose func(GoAwayReason)
bufferPool *bufferPool bufferPool mem.BufferPool
connectionID uint64 connectionID uint64
logger *grpclog.PrefixLogger logger *grpclog.PrefixLogger
@ -231,7 +232,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
} }
}(conn) }(conn)
// The following defer and goroutine monitor the connectCtx for cancelation // The following defer and goroutine monitor the connectCtx for cancellation
// and deadline. On context expiration, the connection is hard closed and // and deadline. On context expiration, the connection is hard closed and
// this function will naturally fail as a result. Otherwise, the defer // this function will naturally fail as a result. Otherwise, the defer
// waits for the goroutine to exit to prevent the context from being // waits for the goroutine to exit to prevent the context from being
@ -319,6 +320,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if opts.MaxHeaderListSize != nil { if opts.MaxHeaderListSize != nil {
maxHeaderListSize = *opts.MaxHeaderListSize maxHeaderListSize = *opts.MaxHeaderListSize
} }
t := &http2Client{ t := &http2Client{
ctx: ctx, ctx: ctx,
ctxDone: ctx.Done(), // Cache Done chan. ctxDone: ctx.Done(), // Cache Done chan.
@ -346,11 +348,25 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
maxConcurrentStreams: defaultMaxStreamsClient, maxConcurrentStreams: defaultMaxStreamsClient,
streamQuota: defaultMaxStreamsClient, streamQuota: defaultMaxStreamsClient,
streamsQuotaAvailable: make(chan struct{}, 1), streamsQuotaAvailable: make(chan struct{}, 1),
czData: new(channelzData),
keepaliveEnabled: keepaliveEnabled, keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(), bufferPool: opts.BufferPool,
onClose: onClose, onClose: onClose,
} }
var czSecurity credentials.ChannelzSecurityValue
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
czSecurity = au.GetSecurityValue()
}
t.channelz = channelz.RegisterSocket(
&channelz.Socket{
SocketType: channelz.SocketTypeNormal,
Parent: opts.ChannelzParent,
SocketMetrics: channelz.SocketMetrics{},
EphemeralMetrics: t.socketMetrics,
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
SocketOptions: channelz.GetSocketOption(t.conn),
Security: czSecurity,
})
t.logger = prefixLoggerForClientTransport(t) t.logger = prefixLoggerForClientTransport(t)
// Add peer information to the http2client context. // Add peer information to the http2client context.
t.ctx = peer.NewContext(t.ctx, t.getPeer()) t.ctx = peer.NewContext(t.ctx, t.getPeer())
@ -381,10 +397,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
} }
sh.HandleConn(t.ctx, connBegin) sh.HandleConn(t.ctx, connBegin)
} }
t.channelzID, err = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr))
if err != nil {
return nil, err
}
if t.keepaliveEnabled { if t.keepaliveEnabled {
t.kpDormancyCond = sync.NewCond(&t.mu) t.kpDormancyCond = sync.NewCond(&t.mu)
go t.keepalive() go t.keepalive()
@ -399,10 +411,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerErrCh := make(chan error, 1) readerErrCh := make(chan error, 1)
go t.reader(readerErrCh) go t.reader(readerErrCh)
defer func() { defer func() {
if err == nil {
err = <-readerErrCh
}
if err != nil { if err != nil {
// writerDone should be closed since the loopy goroutine
// wouldn't have started in the case this function returns an error.
close(t.writerDone)
t.Close(err) t.Close(err)
} }
}() }()
@ -449,8 +461,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
if err := t.framer.writer.Flush(); err != nil { if err := t.framer.writer.Flush(); err != nil {
return nil, err return nil, err
} }
// Block until the server preface is received successfully or an error occurs.
if err = <-readerErrCh; err != nil {
return nil, err
}
go func() { go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
if err := t.loopy.run(); !isIOError(err) { if err := t.loopy.run(); !isIOError(err) {
// Immediately close the connection, as the loopy writer returns // Immediately close the connection, as the loopy writer returns
// when there are no more active streams and we were draining (the // when there are no more active streams and we were draining (the
@ -491,7 +507,6 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
closeStream: func(err error) { closeStream: func(err error) {
t.CloseStream(s, err) t.CloseStream(s, err)
}, },
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -508,6 +523,17 @@ func (t *http2Client) getPeer() *peer.Peer {
} }
} }
// OutgoingGoAwayHandler writes a GOAWAY to the connection. Always returns (false, err) as we want the GoAway
// to be the last frame loopy writes to the transport.
func (t *http2Client) outgoingGoAwayHandler(g *goAway) (bool, error) {
t.mu.Lock()
defer t.mu.Unlock()
if err := t.framer.fr.WriteGoAway(t.nextID-2, http2.ErrCodeNo, g.debugData); err != nil {
return false, err
}
return false, g.closeConn
}
func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) { func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) {
aud := t.createAudience(callHdr) aud := t.createAudience(callHdr)
ri := credentials.RequestInfo{ ri := credentials.RequestInfo{
@ -756,8 +782,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return ErrConnClosing return ErrConnClosing
} }
if channelz.IsOn() { if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1) t.channelz.SocketMetrics.StreamsStarted.Add(1)
atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano()) t.channelz.SocketMetrics.LastLocalStreamCreatedTimestamp.Store(time.Now().UnixNano())
} }
// If the keepalive goroutine has gone dormant, wake it up. // If the keepalive goroutine has gone dormant, wake it up.
if t.kpDormant { if t.kpDormant {
@ -772,7 +798,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
firstTry := true firstTry := true
var ch chan struct{} var ch chan struct{}
transportDrainRequired := false transportDrainRequired := false
checkForStreamQuota := func(it any) bool { checkForStreamQuota := func() bool {
if t.streamQuota <= 0 { // Can go negative if server decreases it. if t.streamQuota <= 0 { // Can go negative if server decreases it.
if firstTry { if firstTry {
t.waitingStreams++ t.waitingStreams++
@ -784,23 +810,24 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
t.waitingStreams-- t.waitingStreams--
} }
t.streamQuota-- t.streamQuota--
h := it.(*headerFrame)
h.streamID = t.nextID
t.nextID += 2
// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.mu.Lock() t.mu.Lock()
if t.state == draining || t.activeStreams == nil { // Can be niled from Close(). if t.state == draining || t.activeStreams == nil { // Can be niled from Close().
t.mu.Unlock() t.mu.Unlock()
return false // Don't create a stream if the transport is already closed. return false // Don't create a stream if the transport is already closed.
} }
hdr.streamID = t.nextID
t.nextID += 2
// Drain client transport if nextID > MaxStreamID which signals gRPC that
// the connection is closed and a new one must be created for subsequent RPCs.
transportDrainRequired = t.nextID > MaxStreamID
s.id = hdr.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
t.mu.Unlock() t.mu.Unlock()
if t.streamQuota > 0 && t.waitingStreams > 0 { if t.streamQuota > 0 && t.waitingStreams > 0 {
select { select {
case t.streamsQuotaAvailable <- struct{}{}: case t.streamsQuotaAvailable <- struct{}{}:
@ -810,13 +837,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return true return true
} }
var hdrListSizeErr error var hdrListSizeErr error
checkForHeaderListSize := func(it any) bool { checkForHeaderListSize := func() bool {
if t.maxSendHeaderListSize == nil { if t.maxSendHeaderListSize == nil {
return true return true
} }
hdrFrame := it.(*headerFrame)
var sz int64 var sz int64
for _, f := range hdrFrame.hf { for _, f := range hdr.hf {
if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) {
hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize) hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize)
return false return false
@ -825,8 +851,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
return true return true
} }
for { for {
success, err := t.controlBuf.executeAndPut(func(it any) bool { success, err := t.controlBuf.executeAndPut(func() bool {
return checkForHeaderListSize(it) && checkForStreamQuota(it) return checkForHeaderListSize() && checkForStreamQuota()
}, hdr) }, hdr)
if err != nil { if err != nil {
// Connection closed. // Connection closed.
@ -928,16 +954,16 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
t.mu.Unlock() t.mu.Unlock()
if channelz.IsOn() { if channelz.IsOn() {
if eosReceived { if eosReceived {
atomic.AddInt64(&t.czData.streamsSucceeded, 1) t.channelz.SocketMetrics.StreamsSucceeded.Add(1)
} else { } else {
atomic.AddInt64(&t.czData.streamsFailed, 1) t.channelz.SocketMetrics.StreamsFailed.Add(1)
} }
} }
}, },
rst: rst, rst: rst,
rstCode: rstCode, rstCode: rstCode,
} }
addBackStreamQuota := func(any) bool { addBackStreamQuota := func() bool {
t.streamQuota++ t.streamQuota++
if t.streamQuota > 0 && t.waitingStreams > 0 { if t.streamQuota > 0 && t.waitingStreams > 0 {
select { select {
@ -957,8 +983,9 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// Close kicks off the shutdown process of the transport. This should be called // Close kicks off the shutdown process of the transport. This should be called
// only once on a transport. Once it is called, the transport should not be // only once on a transport. Once it is called, the transport should not be
// accessed any more. // accessed anymore.
func (t *http2Client) Close(err error) { func (t *http2Client) Close(err error) {
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
t.mu.Lock() t.mu.Lock()
// Make sure we only close once. // Make sure we only close once.
if t.state == closing { if t.state == closing {
@ -982,10 +1009,23 @@ func (t *http2Client) Close(err error) {
t.kpDormancyCond.Signal() t.kpDormancyCond.Signal()
} }
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.finish()
// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. It
// also waits for loopyWriter to be closed with a timer to avoid the
// long blocking in case the connection is blackholed, i.e. TCP is
// just stuck.
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
timer := time.NewTimer(goAwayLoopyWriterTimeout)
defer timer.Stop()
select {
case <-t.writerDone: // success
case <-timer.C:
t.logger.Infof("Failed to write a GOAWAY frame as part of connection close after %s. Giving up and closing the transport.", goAwayLoopyWriterTimeout)
}
t.cancel() t.cancel()
t.conn.Close() t.conn.Close()
channelz.RemoveEntry(t.channelzID) channelz.RemoveEntry(t.channelz.ID)
// Append info about previous goaways if there were any, since this may be important // Append info about previous goaways if there were any, since this may be important
// for understanding the root cause for this connection to be closed. // for understanding the root cause for this connection to be closed.
_, goAwayDebugMessage := t.GetGoAwayReason() _, goAwayDebugMessage := t.GetGoAwayReason()
@ -1038,27 +1078,36 @@ func (t *http2Client) GracefulClose() {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil. // should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Client) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
reader := data.Reader()
if opts.Last { if opts.Last {
// If it's the last message, update stream state. // If it's the last message, update stream state.
if !s.compareAndSwapState(streamActive, streamWriteDone) { if !s.compareAndSwapState(streamActive, streamWriteDone) {
_ = reader.Close()
return errStreamDone return errStreamDone
} }
} else if s.getState() != streamActive { } else if s.getState() != streamActive {
_ = reader.Close()
return errStreamDone return errStreamDone
} }
df := &dataFrame{ df := &dataFrame{
streamID: s.id, streamID: s.id,
endStream: opts.Last, endStream: opts.Last,
h: hdr, h: hdr,
d: data, reader: reader,
} }
if hdr != nil || data != nil { // If it's not an empty data frame, check quota. if hdr != nil || df.reader.Remaining() != 0 { // If it's not an empty data frame, check quota.
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
_ = reader.Close()
return err return err
} }
} }
return t.controlBuf.put(df) if err := t.controlBuf.put(df); err != nil {
_ = reader.Close()
return err
}
return nil
} }
func (t *http2Client) getStream(f http2.Frame) *Stream { func (t *http2Client) getStream(f http2.Frame) *Stream {
@ -1090,7 +1139,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
// for the transport and the stream based on the current bdp // for the transport and the stream based on the current bdp
// estimation. // estimation.
func (t *http2Client) updateFlowControl(n uint32) { func (t *http2Client) updateFlowControl(n uint32) {
updateIWS := func(any) bool { updateIWS := func() bool {
t.initialWindowSize = int32(n) t.initialWindowSize = int32(n)
t.mu.Lock() t.mu.Lock()
for _, s := range t.activeStreams { for _, s := range t.activeStreams {
@ -1163,10 +1212,13 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
buffer := t.bufferPool.get() pool := t.bufferPool
buffer.Reset() if pool == nil {
buffer.Write(f.Data()) // Note that this is only supposed to be nil in tests. Otherwise, stream is
s.write(recvMsg{buffer: buffer}) // always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
} }
} }
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that
@ -1195,7 +1247,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
if statusCode == codes.Canceled { if statusCode == codes.Canceled {
if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) { if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
// Our deadline was already exceeded, and that was likely the cause // Our deadline was already exceeded, and that was likely the cause
// of this cancelation. Alter the status code accordingly. // of this cancellation. Alter the status code accordingly.
statusCode = codes.DeadlineExceeded statusCode = codes.DeadlineExceeded
} }
} }
@ -1243,7 +1295,7 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {
} }
updateFuncs = append(updateFuncs, updateStreamQuota) updateFuncs = append(updateFuncs, updateStreamQuota)
} }
t.controlBuf.executeAndPut(func(any) bool { t.controlBuf.executeAndPut(func() bool {
for _, f := range updateFuncs { for _, f := range updateFuncs {
f() f()
} }
@ -1280,7 +1332,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID id := f.LastStreamID
if id > 0 && id%2 == 0 { if id > 0 && id%2 == 0 {
t.mu.Unlock() t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id))
return return
} }
// A client can receive multiple GoAways from the server (see // A client can receive multiple GoAways from the server (see
@ -1708,7 +1760,7 @@ func (t *http2Client) keepalive() {
// keepalive timer expired. In both cases, we need to send a ping. // keepalive timer expired. In both cases, we need to send a ping.
if !outstandingPing { if !outstandingPing {
if channelz.IsOn() { if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1) t.channelz.SocketMetrics.KeepAlivesSent.Add(1)
} }
t.controlBuf.put(p) t.controlBuf.put(p)
timeoutLeft = t.kp.Timeout timeoutLeft = t.kp.Timeout
@ -1738,40 +1790,23 @@ func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway return t.goAway
} }
func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric { func (t *http2Client) socketMetrics() *channelz.EphemeralSocketMetrics {
s := channelz.SocketInternalMetric{ return &channelz.EphemeralSocketMetrics{
StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted),
StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded),
StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed),
MessagesSent: atomic.LoadInt64(&t.czData.msgSent),
MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv),
KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount),
LastLocalStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)),
LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)),
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()), LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn), RemoteFlowControlWindow: t.getOutFlowWindow(),
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
// RemoteName :
} }
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
return &s
} }
func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr } func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr }
func (t *http2Client) IncrMsgSent() { func (t *http2Client) IncrMsgSent() {
atomic.AddInt64(&t.czData.msgSent, 1) t.channelz.SocketMetrics.MessagesSent.Add(1)
atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano()) t.channelz.SocketMetrics.LastMessageSentTimestamp.Store(time.Now().UnixNano())
} }
func (t *http2Client) IncrMsgRecv() { func (t *http2Client) IncrMsgRecv() {
atomic.AddInt64(&t.czData.msgRecv, 1) t.channelz.SocketMetrics.MessagesReceived.Add(1)
atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano()) t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Store(time.Now().UnixNano())
} }
func (t *http2Client) getOutFlowWindow() int64 { func (t *http2Client) getOutFlowWindow() int64 {

View File

@ -25,6 +25,7 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"math/rand"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
@ -38,12 +39,12 @@ import (
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/mem"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -118,9 +119,8 @@ type http2Server struct {
idle time.Time idle time.Time
// Fields below are for channelz metric collection. // Fields below are for channelz metric collection.
channelzID *channelz.Identifier channelz *channelz.Socket
czData *channelzData bufferPool mem.BufferPool
bufferPool *bufferPool
connectionID uint64 connectionID uint64
@ -262,9 +262,24 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
idle: time.Now(), idle: time.Now(),
kep: kep, kep: kep,
initialWindowSize: iwz, initialWindowSize: iwz,
czData: new(channelzData), bufferPool: config.BufferPool,
bufferPool: newBufferPool(),
} }
var czSecurity credentials.ChannelzSecurityValue
if au, ok := authInfo.(credentials.ChannelzSecurityInfo); ok {
czSecurity = au.GetSecurityValue()
}
t.channelz = channelz.RegisterSocket(
&channelz.Socket{
SocketType: channelz.SocketTypeNormal,
Parent: config.ChannelzParent,
SocketMetrics: channelz.SocketMetrics{},
EphemeralMetrics: t.socketMetrics,
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
SocketOptions: channelz.GetSocketOption(t.conn),
Security: czSecurity,
},
)
t.logger = prefixLoggerForServerTransport(t) t.logger = prefixLoggerForServerTransport(t)
t.controlBuf = newControlBuffer(t.done) t.controlBuf = newControlBuffer(t.done)
@ -274,10 +289,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl, updateFlowControl: t.updateFlowControl,
} }
} }
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
if err != nil {
return nil, err
}
t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1) t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1)
t.framer.writer.Flush() t.framer.writer.Flush()
@ -320,8 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
t.handleSettings(sf) t.handleSettings(sf)
go func() { go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger, t.outgoingGoAwayHandler, t.bufferPool)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
err := t.loopy.run() err := t.loopy.run()
close(t.loopyWriterDone) close(t.loopyWriterDone)
if !isIOError(err) { if !isIOError(err) {
@ -334,9 +344,11 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// closed, would lead to a TCP RST instead of FIN, and the client // closed, would lead to a TCP RST instead of FIN, and the client
// encountering errors. For more info: // encountering errors. For more info:
// https://github.com/grpc/grpc-go/issues/5358 // https://github.com/grpc/grpc-go/issues/5358
timer := time.NewTimer(time.Second)
defer timer.Stop()
select { select {
case <-t.readerDone: case <-t.readerDone:
case <-time.After(time.Second): case <-timer.C:
} }
t.conn.Close() t.conn.Close()
} }
@ -592,8 +604,8 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
} }
t.mu.Unlock() t.mu.Unlock()
if channelz.IsOn() { if channelz.IsOn() {
atomic.AddInt64(&t.czData.streamsStarted, 1) t.channelz.SocketMetrics.StreamsStarted.Add(1)
atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano()) t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano())
} }
s.requestRead = func(n int) { s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n)) t.adjustWindow(s, uint32(n))
@ -605,7 +617,6 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctxDone, ctxDone: s.ctxDone,
recv: s.buf, recv: s.buf,
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -658,8 +669,14 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
if err := t.operateHeaders(ctx, frame, handle); err != nil { if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err) // Any error processing client headers, e.g. invalid stream ID,
break // is considered a protocol violation.
t.controlBuf.put(&goAway{
code: http2.ErrCodeProtocol,
debugData: []byte(err.Error()),
closeConn: err,
})
continue
} }
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
@ -796,10 +813,13 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
buffer := t.bufferPool.get() pool := t.bufferPool
buffer.Reset() if pool == nil {
buffer.Write(f.Data()) // Note that this is only supposed to be nil in tests. Otherwise, stream is
s.write(recvMsg{buffer: buffer}) // always initialized with a BufferPool.
pool = mem.DefaultBufferPool()
}
s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)})
} }
} }
if f.StreamEnded() { if f.StreamEnded() {
@ -842,7 +862,7 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) {
} }
return nil return nil
}) })
t.controlBuf.executeAndPut(func(any) bool { t.controlBuf.executeAndPut(func() bool {
for _, f := range updateFuncs { for _, f := range updateFuncs {
f() f()
} }
@ -996,12 +1016,13 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
} }
headerFields = appendHeaderFieldsFromMD(headerFields, s.header) headerFields = appendHeaderFieldsFromMD(headerFields, s.header)
success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{ hf := &headerFrame{
streamID: s.id, streamID: s.id,
hf: headerFields, hf: headerFields,
endStream: false, endStream: false,
onWrite: t.setResetPingStrikes, onWrite: t.setResetPingStrikes,
}) }
success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf)
if !success { if !success {
if err != nil { if err != nil {
return err return err
@ -1071,7 +1092,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
onWrite: t.setResetPingStrikes, onWrite: t.setResetPingStrikes,
} }
success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader) success, err := t.controlBuf.executeAndPut(func() bool {
return t.checkForHeaderListSize(trailingHeader)
}, nil)
if !success { if !success {
if err != nil { if err != nil {
return err return err
@ -1094,27 +1117,37 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error). // is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Server) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
reader := data.Reader()
if !s.isHeaderSent() { // Headers haven't been written yet. if !s.isHeaderSent() { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil { if err := t.WriteHeader(s, nil); err != nil {
_ = reader.Close()
return err return err
} }
} else { } else {
// Writing headers checks for this condition. // Writing headers checks for this condition.
if s.getState() == streamDone { if s.getState() == streamDone {
_ = reader.Close()
return t.streamContextErr(s) return t.streamContextErr(s)
} }
} }
df := &dataFrame{ df := &dataFrame{
streamID: s.id, streamID: s.id,
h: hdr, h: hdr,
d: data, reader: reader,
onEachWrite: t.setResetPingStrikes, onEachWrite: t.setResetPingStrikes,
} }
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { if err := s.wq.get(int32(len(hdr) + df.reader.Remaining())); err != nil {
_ = reader.Close()
return t.streamContextErr(s) return t.streamContextErr(s)
} }
return t.controlBuf.put(df) if err := t.controlBuf.put(df); err != nil {
_ = reader.Close()
return err
}
return nil
} }
// keepalive running in a separate goroutine does the following: // keepalive running in a separate goroutine does the following:
@ -1190,12 +1223,12 @@ func (t *http2Server) keepalive() {
continue continue
} }
if outstandingPing && kpTimeoutLeft <= 0 { if outstandingPing && kpTimeoutLeft <= 0 {
t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Time)) t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Timeout))
return return
} }
if !outstandingPing { if !outstandingPing {
if channelz.IsOn() { if channelz.IsOn() {
atomic.AddInt64(&t.czData.kpCount, 1) t.channelz.SocketMetrics.KeepAlivesSent.Add(1)
} }
t.controlBuf.put(p) t.controlBuf.put(p)
kpTimeoutLeft = t.kp.Timeout kpTimeoutLeft = t.kp.Timeout
@ -1235,7 +1268,7 @@ func (t *http2Server) Close(err error) {
if err := t.conn.Close(); err != nil && t.logger.V(logLevel) { if err := t.conn.Close(); err != nil && t.logger.V(logLevel) {
t.logger.Infof("Error closing underlying net.Conn during Close: %v", err) t.logger.Infof("Error closing underlying net.Conn during Close: %v", err)
} }
channelz.RemoveEntry(t.channelzID) channelz.RemoveEntry(t.channelz.ID)
// Cancel all active streams. // Cancel all active streams.
for _, s := range streams { for _, s := range streams {
s.cancel() s.cancel()
@ -1256,9 +1289,9 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
if channelz.IsOn() { if channelz.IsOn() {
if eosReceived { if eosReceived {
atomic.AddInt64(&t.czData.streamsSucceeded, 1) t.channelz.SocketMetrics.StreamsSucceeded.Add(1)
} else { } else {
atomic.AddInt64(&t.czData.streamsFailed, 1) t.channelz.SocketMetrics.StreamsFailed.Add(1)
} }
} }
} }
@ -1375,38 +1408,21 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
return false, nil return false, nil
} }
func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric { func (t *http2Server) socketMetrics() *channelz.EphemeralSocketMetrics {
s := channelz.SocketInternalMetric{ return &channelz.EphemeralSocketMetrics{
StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted),
StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded),
StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed),
MessagesSent: atomic.LoadInt64(&t.czData.msgSent),
MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv),
KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount),
LastRemoteStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)),
LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)),
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()), LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn), RemoteFlowControlWindow: t.getOutFlowWindow(),
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
// RemoteName :
} }
if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
return &s
} }
func (t *http2Server) IncrMsgSent() { func (t *http2Server) IncrMsgSent() {
atomic.AddInt64(&t.czData.msgSent, 1) t.channelz.SocketMetrics.MessagesSent.Add(1)
atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano()) t.channelz.SocketMetrics.LastMessageSentTimestamp.Add(1)
} }
func (t *http2Server) IncrMsgRecv() { func (t *http2Server) IncrMsgRecv() {
atomic.AddInt64(&t.czData.msgRecv, 1) t.channelz.SocketMetrics.MessagesReceived.Add(1)
atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano()) t.channelz.SocketMetrics.LastMessageReceivedTimestamp.Add(1)
} }
func (t *http2Server) getOutFlowWindow() int64 { func (t *http2Server) getOutFlowWindow() int64 {
@ -1439,7 +1455,7 @@ func getJitter(v time.Duration) time.Duration {
} }
// Generate a jitter between +/- 10% of the value. // Generate a jitter between +/- 10% of the value.
r := int64(v / 10) r := int64(v / 10)
j := grpcrand.Int63n(2*r) - r j := rand.Int63n(2*r) - r
return time.Duration(j) return time.Duration(j)
} }

View File

@ -317,28 +317,32 @@ func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
return w return w
} }
func (w *bufWriter) Write(b []byte) (n int, err error) { func (w *bufWriter) Write(b []byte) (int, error) {
if w.err != nil { if w.err != nil {
return 0, w.err return 0, w.err
} }
if w.batchSize == 0 { // Buffer has been disabled. if w.batchSize == 0 { // Buffer has been disabled.
n, err = w.conn.Write(b) n, err := w.conn.Write(b)
return n, toIOError(err) return n, toIOError(err)
} }
if w.buf == nil { if w.buf == nil {
b := w.pool.Get().(*[]byte) b := w.pool.Get().(*[]byte)
w.buf = *b w.buf = *b
} }
written := 0
for len(b) > 0 { for len(b) > 0 {
nn := copy(w.buf[w.offset:], b) copied := copy(w.buf[w.offset:], b)
b = b[nn:] b = b[copied:]
w.offset += nn written += copied
n += nn w.offset += copied
if w.offset >= w.batchSize { if w.offset < w.batchSize {
err = w.flushKeepBuffer() continue
}
if err := w.flushKeepBuffer(); err != nil {
return written, err
} }
} }
return n, err return written, nil
} }
func (w *bufWriter) Flush() error { func (w *bufWriter) Flush() error {
@ -418,10 +422,9 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu
return f return f
} }
func getWriteBufferPool(writeBufferSize int) *sync.Pool { func getWriteBufferPool(size int) *sync.Pool {
writeBufferMutex.Lock() writeBufferMutex.Lock()
defer writeBufferMutex.Unlock() defer writeBufferMutex.Unlock()
size := writeBufferSize * 2
pool, ok := writeBufferPoolMap[size] pool, ok := writeBufferPoolMap[size]
if ok { if ok {
return pool return pool

View File

@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
} }
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
} }
// The buffer could contain extra bytes from the target server, so we can't
// discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
} }
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy // proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy

View File

@ -22,12 +22,12 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -36,6 +36,7 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
@ -46,32 +47,10 @@ import (
const logLevel = 2 const logLevel = 2
type bufferPool struct {
pool sync.Pool
}
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
},
}
}
func (p *bufferPool) get() *bytes.Buffer {
return p.pool.Get().(*bytes.Buffer)
}
func (p *bufferPool) put(b *bytes.Buffer) {
p.pool.Put(b)
}
// recvMsg represents the received msg from the transport. All transport // recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed. // protocol specific info has been removed.
type recvMsg struct { type recvMsg struct {
buffer *bytes.Buffer buffer mem.Buffer
// nil: received some data // nil: received some data
// io.EOF: stream is completed. data is nil. // io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil. // other non-nil error: transport failure. data is nil.
@ -101,6 +80,9 @@ func newRecvBuffer() *recvBuffer {
func (b *recvBuffer) put(r recvMsg) { func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock() b.mu.Lock()
if b.err != nil { if b.err != nil {
// drop the buffer on the floor. Since b.err is not nil, any subsequent reads
// will always return an error, making this buffer inaccessible.
r.buffer.Free()
b.mu.Unlock() b.mu.Unlock()
// An error had occurred earlier, don't accept more // An error had occurred earlier, don't accept more
// data or errors. // data or errors.
@ -147,45 +129,70 @@ type recvBufferReader struct {
ctx context.Context ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance). ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer recv *recvBuffer
last *bytes.Buffer // Stores the remaining data in the previous calls. last mem.Buffer // Stores the remaining data in the previous calls.
err error err error
freeBuffer func(*bytes.Buffer)
} }
// Read reads the next len(p) bytes from last. If last is drained, it tries to func (r *recvBufferReader) ReadHeader(header []byte) (n int, err error) {
// read additional data from recv. It blocks if there no additional data available
// in recv. If Read returns any non-nil error, it will continue to return that error.
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil { if r.err != nil {
return 0, r.err return 0, r.err
} }
if r.last != nil { if r.last != nil {
// Read remaining data left in last call. n, r.last = mem.ReadUnsafe(header, r.last)
copied, _ := r.last.Read(p) return n, nil
if r.last.Len() == 0 {
r.freeBuffer(r.last)
r.last = nil
}
return copied, nil
} }
if r.closeStream != nil { if r.closeStream != nil {
n, r.err = r.readClient(p) n, r.err = r.readHeaderClient(header)
} else { } else {
n, r.err = r.read(p) n, r.err = r.readHeader(header)
} }
return n, r.err return n, r.err
} }
func (r *recvBufferReader) read(p []byte) (n int, err error) { // Read reads the next n bytes from last. If last is drained, it tries to read
// additional data from recv. It blocks if there no additional data available in
// recv. If Read returns any non-nil error, it will continue to return that
// error.
func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) {
if r.err != nil {
return nil, r.err
}
if r.last != nil {
buf = r.last
if r.last.Len() > n {
buf, r.last = mem.SplitUnsafe(buf, n)
} else {
r.last = nil
}
return buf, nil
}
if r.closeStream != nil {
buf, r.err = r.readClient(n)
} else {
buf, r.err = r.read(n)
}
return buf, r.err
}
func (r *recvBufferReader) readHeader(header []byte) (n int, err error) {
select { select {
case <-r.ctxDone: case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err()) return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get(): case m := <-r.recv.get():
return r.readAdditional(m, p) return r.readHeaderAdditional(m, header)
} }
} }
func (r *recvBufferReader) readClient(p []byte) (n int, err error) { func (r *recvBufferReader) read(n int) (buf mem.Buffer, err error) {
select {
case <-r.ctxDone:
return nil, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
return r.readAdditional(m, n)
}
}
func (r *recvBufferReader) readHeaderClient(header []byte) (n int, err error) {
// If the context is canceled, then closes the stream with nil metadata. // If the context is canceled, then closes the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg. // closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error. // r.readAdditional acts on that message and returns the necessary error.
@ -206,25 +213,67 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
// faster. // faster.
r.closeStream(ContextErr(r.ctx.Err())) r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get() m := <-r.recv.get()
return r.readAdditional(m, p) return r.readHeaderAdditional(m, header)
case m := <-r.recv.get(): case m := <-r.recv.get():
return r.readAdditional(m, p) return r.readHeaderAdditional(m, header)
} }
} }
func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) { func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) {
// If the context is canceled, then closes the stream with nil metadata.
// closeStream writes its error parameter to r.recv as a recvMsg.
// r.readAdditional acts on that message and returns the necessary error.
select {
case <-r.ctxDone:
// Note that this adds the ctx error to the end of recv buffer, and
// reads from the head. This will delay the error until recv buffer is
// empty, thus will delay ctx cancellation in Recv().
//
// It's done this way to fix a race between ctx cancel and trailer. The
// race was, stream.Recv() may return ctx error if ctxDone wins the
// race, but stream.Trailer() may return a non-nil md because the stream
// was not marked as done when trailer is received. This closeStream
// call will mark stream as done, thus fix the race.
//
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get()
return r.readAdditional(m, n)
case m := <-r.recv.get():
return r.readAdditional(m, n)
}
}
func (r *recvBufferReader) readHeaderAdditional(m recvMsg, header []byte) (n int, err error) {
r.recv.load() r.recv.load()
if m.err != nil { if m.err != nil {
if m.buffer != nil {
m.buffer.Free()
}
return 0, m.err return 0, m.err
} }
copied, _ := m.buffer.Read(p)
if m.buffer.Len() == 0 { n, r.last = mem.ReadUnsafe(header, m.buffer)
r.freeBuffer(m.buffer)
r.last = nil return n, nil
} else { }
r.last = m.buffer
func (r *recvBufferReader) readAdditional(m recvMsg, n int) (b mem.Buffer, err error) {
r.recv.load()
if m.err != nil {
if m.buffer != nil {
m.buffer.Free()
} }
return copied, nil return nil, m.err
}
if m.buffer.Len() > n {
m.buffer, r.last = mem.SplitUnsafe(m.buffer, n)
}
return m.buffer, nil
} }
type streamState uint32 type streamState uint32
@ -240,7 +289,7 @@ const (
type Stream struct { type Stream struct {
id uint32 id uint32
st ServerTransport // nil for client side Stream st ServerTransport // nil for client side Stream
ct *http2Client // nil for server side Stream ct ClientTransport // nil for server side Stream
ctx context.Context // the associated context of the stream ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side. done chan struct{} // closed at the end of stream to unblock writers. On the client side.
@ -250,7 +299,7 @@ type Stream struct {
recvCompress string recvCompress string
sendCompress string sendCompress string
buf *recvBuffer buf *recvBuffer
trReader io.Reader trReader *transportReader
fc *inFlow fc *inFlow
wq *writeQuota wq *writeQuota
@ -303,7 +352,7 @@ func (s *Stream) isHeaderSent() bool {
} }
// updateHeaderSent updates headerSent and returns true // updateHeaderSent updates headerSent and returns true
// if it was alreay set. It is valid only on server-side. // if it was already set. It is valid only on server-side.
func (s *Stream) updateHeaderSent() bool { func (s *Stream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1 return atomic.SwapUint32(&s.headerSent, 1) == 1
} }
@ -362,8 +411,12 @@ func (s *Stream) SendCompress() string {
// ClientAdvertisedCompressors returns the compressor names advertised by the // ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header. // client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string { func (s *Stream) ClientAdvertisedCompressors() []string {
return s.clientAdvertisedCompressors values := strings.Split(s.clientAdvertisedCompressors, ",")
for i, v := range values {
values[i] = strings.TrimSpace(v)
}
return values
} }
// Done returns a channel which is closed when it receives the final status // Done returns a channel which is closed when it receives the final status
@ -403,7 +456,7 @@ func (s *Stream) TrailersOnly() bool {
return s.noHeaders return s.noHeaders
} }
// Trailer returns the cached trailer metedata. Note that if it is not called // Trailer returns the cached trailer metadata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client // after the entire stream is done, it could return an empty MD. Client
// side only. // side only.
// It can be safely read only after stream has ended that is either read // It can be safely read only after stream has ended that is either read
@ -494,36 +547,87 @@ func (s *Stream) write(m recvMsg) {
s.buf.put(m) s.buf.put(m)
} }
// Read reads all p bytes from the wire for this stream. func (s *Stream) ReadHeader(header []byte) (err error) {
func (s *Stream) Read(p []byte) (n int, err error) {
// Don't request a read if there was an error earlier // Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil { if er := s.trReader.er; er != nil {
return 0, er return er
} }
s.requestRead(len(p)) s.requestRead(len(header))
return io.ReadFull(s.trReader, p) for len(header) != 0 {
n, err := s.trReader.ReadHeader(header)
header = header[n:]
if len(header) == 0 {
err = nil
}
if err != nil {
if n > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
}
return nil
} }
// tranportReader reads all the data available for this Stream from the transport and // Read reads n bytes from the wire for this stream.
func (s *Stream) Read(n int) (data mem.BufferSlice, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.er; er != nil {
return nil, er
}
s.requestRead(n)
for n != 0 {
buf, err := s.trReader.Read(n)
var bufLen int
if buf != nil {
bufLen = buf.Len()
}
n -= bufLen
if n == 0 {
err = nil
}
if err != nil {
if bufLen > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
}
data.Free()
return nil, err
}
data = append(data, buf)
}
return data, nil
}
// transportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream. // passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if // The error is io.EOF when the stream is done or another non-nil error if
// the stream broke. // the stream broke.
type transportReader struct { type transportReader struct {
reader io.Reader reader *recvBufferReader
// The handler to control the window update procedure for both this // The handler to control the window update procedure for both this
// particular stream and the associated transport. // particular stream and the associated transport.
windowHandler func(int) windowHandler func(int)
er error er error
} }
func (t *transportReader) Read(p []byte) (n int, err error) { func (t *transportReader) ReadHeader(header []byte) (int, error) {
n, err = t.reader.Read(p) n, err := t.reader.ReadHeader(header)
if err != nil { if err != nil {
t.er = err t.er = err
return return 0, err
} }
t.windowHandler(n) t.windowHandler(len(header))
return return n, nil
}
func (t *transportReader) Read(n int) (mem.Buffer, error) {
buf, err := t.reader.Read(n)
if err != nil {
t.er = err
return buf, err
}
t.windowHandler(buf.Len())
return buf, nil
} }
// BytesReceived indicates whether any bytes have been received on this stream. // BytesReceived indicates whether any bytes have been received on this stream.
@ -566,9 +670,10 @@ type ServerConfig struct {
WriteBufferSize int WriteBufferSize int
ReadBufferSize int ReadBufferSize int
SharedWriteBuffer bool SharedWriteBuffer bool
ChannelzParentID *channelz.Identifier ChannelzParent *channelz.Server
MaxHeaderListSize *uint32 MaxHeaderListSize *uint32
HeaderTableSize *uint32 HeaderTableSize *uint32
BufferPool mem.BufferPool
} }
// ConnectOptions covers all relevant options for communicating with the server. // ConnectOptions covers all relevant options for communicating with the server.
@ -601,12 +706,14 @@ type ConnectOptions struct {
ReadBufferSize int ReadBufferSize int
// SharedWriteBuffer indicates whether connections should reuse write buffer // SharedWriteBuffer indicates whether connections should reuse write buffer
SharedWriteBuffer bool SharedWriteBuffer bool
// ChannelzParentID sets the addrConn id which initiate the creation of this client transport. // ChannelzParent sets the addrConn id which initiated the creation of this client transport.
ChannelzParentID *channelz.Identifier ChannelzParent *channelz.SubChannel
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
MaxHeaderListSize *uint32 MaxHeaderListSize *uint32
// UseProxy specifies if a proxy should be used. // UseProxy specifies if a proxy should be used.
UseProxy bool UseProxy bool
// The mem.BufferPool to use when reading/writing to the wire.
BufferPool mem.BufferPool
} }
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
@ -668,7 +775,7 @@ type ClientTransport interface {
// Write sends the data for the given stream. A nil stream indicates // Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole. // the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
// NewStream creates a Stream for an RPC. // NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
@ -720,7 +827,7 @@ type ServerTransport interface {
// Write sends the data for the given stream. // Write sends the data for the given stream.
// Write may not be called on all streams. // Write may not be called on all streams.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is // WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs. // the final call made on a stream and always occurs.
@ -793,7 +900,7 @@ var (
// connection is draining. This could be caused by goaway or balancer // connection is draining. This could be caused by goaway or balancer
// removing the address. // removing the address.
errStreamDrain = status.Error(codes.Unavailable, "the connection is draining") errStreamDrain = status.Error(codes.Unavailable, "the connection is draining")
// errStreamDone is returned from write at the client side to indiacte application // errStreamDone is returned from write at the client side to indicate application
// layer of an error. // layer of an error.
errStreamDone = errors.New("the stream is done") errStreamDone = errors.New("the stream is done")
// StatusGoAway indicates that the server sent a GOAWAY that included this // StatusGoAway indicates that the server sent a GOAWAY that included this
@ -815,30 +922,6 @@ const (
GoAwayTooManyPings GoAwayReason = 2 GoAwayTooManyPings GoAwayReason = 2
) )
// channelzData is used to store channelz related data for http2Client and http2Server.
// These fields cannot be embedded in the original structs (e.g. http2Client), since to do atomic
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
// Here, by grouping those int64 fields inside a struct, we are enforcing the alignment.
type channelzData struct {
kpCount int64
// The number of streams that have started, including already finished ones.
streamsStarted int64
// Client side: The number of streams that have ended successfully by receiving
// EoS bit set frame from server.
// Server side: The number of streams that have ended successfully by sending
// frame with EoS bit set.
streamsSucceeded int64
streamsFailed int64
// lastStreamCreatedTime stores the timestamp that the last stream gets created. It is of int64 type
// instead of time.Time since it's more costly to atomically update time.Time variable than int64
// variable. The same goes for lastMsgSentTime and lastMsgRecvTime.
lastStreamCreatedTime int64
msgSent int64
msgRecv int64
lastMsgSentTime int64
lastMsgRecvTime int64
}
// ContextErr converts the error from context package into a status error. // ContextErr converts the error from context package into a status error.
func ContextErr(err error) error { func ContextErr(err error) error {
switch err { switch err {

View File

@ -1,40 +0,0 @@
/*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package internal
import (
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/resolver"
)
// handshakeClusterNameKey is the type used as the key to store cluster name in
// the Attributes field of resolver.Address.
type handshakeClusterNameKey struct{}
// SetXDSHandshakeClusterName returns a copy of addr in which the Attributes field
// is updated with the cluster name.
func SetXDSHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address {
addr.Attributes = addr.Attributes.WithValue(handshakeClusterNameKey{}, clusterName)
return addr
}
// GetXDSHandshakeClusterName returns cluster name stored in attr.
func GetXDSHandshakeClusterName(attr *attributes.Attributes) (string, bool) {
v := attr.Value(handshakeClusterNameKey{})
name, ok := v.(string)
return name, ok
}

194
vendor/google.golang.org/grpc/mem/buffer_pool.go generated vendored Normal file
View File

@ -0,0 +1,194 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package mem
import (
"sort"
"sync"
"google.golang.org/grpc/internal"
)
// BufferPool is a pool of buffers that can be shared and reused, resulting in
// decreased memory allocation.
type BufferPool interface {
// Get returns a buffer with specified length from the pool.
Get(length int) *[]byte
// Put returns a buffer to the pool.
Put(*[]byte)
}
var defaultBufferPoolSizes = []int{
256,
4 << 10, // 4KB (go page size)
16 << 10, // 16KB (max HTTP/2 frame size used by gRPC)
32 << 10, // 32KB (default buffer size for io.Copy)
1 << 20, // 1MB
}
var defaultBufferPool BufferPool
func init() {
defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...)
internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) {
defaultBufferPool = pool
}
internal.SetBufferPoolingThresholdForTesting = func(threshold int) {
bufferPoolingThreshold = threshold
}
}
// DefaultBufferPool returns the current default buffer pool. It is a BufferPool
// created with NewBufferPool that uses a set of default sizes optimized for
// expected workflows.
func DefaultBufferPool() BufferPool {
return defaultBufferPool
}
// NewTieredBufferPool returns a BufferPool implementation that uses multiple
// underlying pools of the given pool sizes.
func NewTieredBufferPool(poolSizes ...int) BufferPool {
sort.Ints(poolSizes)
pools := make([]*sizedBufferPool, len(poolSizes))
for i, s := range poolSizes {
pools[i] = newSizedBufferPool(s)
}
return &tieredBufferPool{
sizedPools: pools,
}
}
// tieredBufferPool implements the BufferPool interface with multiple tiers of
// buffer pools for different sizes of buffers.
type tieredBufferPool struct {
sizedPools []*sizedBufferPool
fallbackPool simpleBufferPool
}
func (p *tieredBufferPool) Get(size int) *[]byte {
return p.getPool(size).Get(size)
}
func (p *tieredBufferPool) Put(buf *[]byte) {
p.getPool(cap(*buf)).Put(buf)
}
func (p *tieredBufferPool) getPool(size int) BufferPool {
poolIdx := sort.Search(len(p.sizedPools), func(i int) bool {
return p.sizedPools[i].defaultSize >= size
})
if poolIdx == len(p.sizedPools) {
return &p.fallbackPool
}
return p.sizedPools[poolIdx]
}
// sizedBufferPool is a BufferPool implementation that is optimized for specific
// buffer sizes. For example, HTTP/2 frames within gRPC have a default max size
// of 16kb and a sizedBufferPool can be configured to only return buffers with a
// capacity of 16kb. Note that however it does not support returning larger
// buffers and in fact panics if such a buffer is requested. Because of this,
// this BufferPool implementation is not meant to be used on its own and rather
// is intended to be embedded in a tieredBufferPool such that Get is only
// invoked when the required size is smaller than or equal to defaultSize.
type sizedBufferPool struct {
pool sync.Pool
defaultSize int
}
func (p *sizedBufferPool) Get(size int) *[]byte {
buf := p.pool.Get().(*[]byte)
b := *buf
clear(b[:cap(b)])
*buf = b[:size]
return buf
}
func (p *sizedBufferPool) Put(buf *[]byte) {
if cap(*buf) < p.defaultSize {
// Ignore buffers that are too small to fit in the pool. Otherwise, when
// Get is called it will panic as it tries to index outside the bounds
// of the buffer.
return
}
p.pool.Put(buf)
}
func newSizedBufferPool(size int) *sizedBufferPool {
return &sizedBufferPool{
pool: sync.Pool{
New: func() any {
buf := make([]byte, size)
return &buf
},
},
defaultSize: size,
}
}
var _ BufferPool = (*simpleBufferPool)(nil)
// simpleBufferPool is an implementation of the BufferPool interface that
// attempts to pool buffers with a sync.Pool. When Get is invoked, it tries to
// acquire a buffer from the pool but if that buffer is too small, it returns it
// to the pool and creates a new one.
type simpleBufferPool struct {
pool sync.Pool
}
func (p *simpleBufferPool) Get(size int) *[]byte {
bs, ok := p.pool.Get().(*[]byte)
if ok && cap(*bs) >= size {
*bs = (*bs)[:size]
return bs
}
// A buffer was pulled from the pool, but it is too small. Put it back in
// the pool and create one large enough.
if ok {
p.pool.Put(bs)
}
b := make([]byte, size)
return &b
}
func (p *simpleBufferPool) Put(buf *[]byte) {
p.pool.Put(buf)
}
var _ BufferPool = NopBufferPool{}
// NopBufferPool is a buffer pool that returns new buffers without pooling.
type NopBufferPool struct{}
// Get returns a buffer with specified length from the pool.
func (NopBufferPool) Get(length int) *[]byte {
b := make([]byte, length)
return &b
}
// Put returns a buffer to the pool.
func (NopBufferPool) Put(*[]byte) {
}

226
vendor/google.golang.org/grpc/mem/buffer_slice.go generated vendored Normal file
View File

@ -0,0 +1,226 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package mem
import (
"io"
)
// BufferSlice offers a means to represent data that spans one or more Buffer
// instances. A BufferSlice is meant to be immutable after creation, and methods
// like Ref create and return copies of the slice. This is why all methods have
// value receivers rather than pointer receivers.
//
// Note that any of the methods that read the underlying buffers such as Ref,
// Len or CopyTo etc., will panic if any underlying buffers have already been
// freed. It is recommended to not directly interact with any of the underlying
// buffers directly, rather such interactions should be mediated through the
// various methods on this type.
//
// By convention, any APIs that return (mem.BufferSlice, error) should reduce
// the burden on the caller by never returning a mem.BufferSlice that needs to
// be freed if the error is non-nil, unless explicitly stated.
type BufferSlice []Buffer
// Len returns the sum of the length of all the Buffers in this slice.
//
// # Warning
//
// Invoking the built-in len on a BufferSlice will return the number of buffers
// in the slice, and *not* the value returned by this function.
func (s BufferSlice) Len() int {
var length int
for _, b := range s {
length += b.Len()
}
return length
}
// Ref invokes Ref on each buffer in the slice.
func (s BufferSlice) Ref() {
for _, b := range s {
b.Ref()
}
}
// Free invokes Buffer.Free() on each Buffer in the slice.
func (s BufferSlice) Free() {
for _, b := range s {
b.Free()
}
}
// CopyTo copies each of the underlying Buffer's data into the given buffer,
// returning the number of bytes copied. Has the same semantics as the copy
// builtin in that it will copy as many bytes as it can, stopping when either dst
// is full or s runs out of data, returning the minimum of s.Len() and len(dst).
func (s BufferSlice) CopyTo(dst []byte) int {
off := 0
for _, b := range s {
off += copy(dst[off:], b.ReadOnlyData())
}
return off
}
// Materialize concatenates all the underlying Buffer's data into a single
// contiguous buffer using CopyTo.
func (s BufferSlice) Materialize() []byte {
l := s.Len()
if l == 0 {
return nil
}
out := make([]byte, l)
s.CopyTo(out)
return out
}
// MaterializeToBuffer functions like Materialize except that it writes the data
// to a single Buffer pulled from the given BufferPool.
//
// As a special case, if the input BufferSlice only actually has one Buffer, this
// function simply increases the refcount before returning said Buffer. Freeing this
// buffer won't release it until the BufferSlice is itself released.
func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer {
if len(s) == 1 {
s[0].Ref()
return s[0]
}
sLen := s.Len()
if sLen == 0 {
return emptyBuffer{}
}
buf := pool.Get(sLen)
s.CopyTo(*buf)
return NewBuffer(buf, pool)
}
// Reader returns a new Reader for the input slice after taking references to
// each underlying buffer.
func (s BufferSlice) Reader() Reader {
s.Ref()
return &sliceReader{
data: s,
len: s.Len(),
}
}
// Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface
// with other parts systems. It also provides an additional convenience method
// Remaining(), which returns the number of unread bytes remaining in the slice.
// Buffers will be freed as they are read.
type Reader interface {
io.Reader
io.ByteReader
// Close frees the underlying BufferSlice and never returns an error. Subsequent
// calls to Read will return (0, io.EOF).
Close() error
// Remaining returns the number of unread bytes remaining in the slice.
Remaining() int
}
type sliceReader struct {
data BufferSlice
len int
// The index into data[0].ReadOnlyData().
bufferIdx int
}
func (r *sliceReader) Remaining() int {
return r.len
}
func (r *sliceReader) Close() error {
r.data.Free()
r.data = nil
r.len = 0
return nil
}
func (r *sliceReader) freeFirstBufferIfEmpty() bool {
if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) {
return false
}
r.data[0].Free()
r.data = r.data[1:]
r.bufferIdx = 0
return true
}
func (r *sliceReader) Read(buf []byte) (n int, _ error) {
if r.len == 0 {
return 0, io.EOF
}
for len(buf) != 0 && r.len != 0 {
// Copy as much as possible from the first Buffer in the slice into the
// given byte slice.
data := r.data[0].ReadOnlyData()
copied := copy(buf, data[r.bufferIdx:])
r.len -= copied // Reduce len by the number of bytes copied.
r.bufferIdx += copied // Increment the buffer index.
n += copied // Increment the total number of bytes read.
buf = buf[copied:] // Shrink the given byte slice.
// If we have copied all the data from the first Buffer, free it and advance to
// the next in the slice.
r.freeFirstBufferIfEmpty()
}
return n, nil
}
func (r *sliceReader) ReadByte() (byte, error) {
if r.len == 0 {
return 0, io.EOF
}
// There may be any number of empty buffers in the slice, clear them all until a
// non-empty buffer is reached. This is guaranteed to exit since r.len is not 0.
for r.freeFirstBufferIfEmpty() {
}
b := r.data[0].ReadOnlyData()[r.bufferIdx]
r.len--
r.bufferIdx++
// Free the first buffer in the slice if the last byte was read
r.freeFirstBufferIfEmpty()
return b, nil
}
var _ io.Writer = (*writer)(nil)
type writer struct {
buffers *BufferSlice
pool BufferPool
}
func (w *writer) Write(p []byte) (n int, err error) {
b := Copy(p, w.pool)
*w.buffers = append(*w.buffers, b)
return b.Len(), nil
}
// NewWriter wraps the given BufferSlice and BufferPool to implement the
// io.Writer interface. Every call to Write copies the contents of the given
// buffer into a new Buffer pulled from the given pool and the Buffer is added to
// the given BufferSlice.
func NewWriter(buffers *BufferSlice, pool BufferPool) io.Writer {
return &writer{buffers: buffers, pool: pool}
}

252
vendor/google.golang.org/grpc/mem/buffers.go generated vendored Normal file
View File

@ -0,0 +1,252 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package mem provides utilities that facilitate memory reuse in byte slices
// that are used as buffers.
//
// # Experimental
//
// Notice: All APIs in this package are EXPERIMENTAL and may be changed or
// removed in a later release.
package mem
import (
"fmt"
"sync"
"sync/atomic"
)
// A Buffer represents a reference counted piece of data (in bytes) that can be
// acquired by a call to NewBuffer() or Copy(). A reference to a Buffer may be
// released by calling Free(), which invokes the free function given at creation
// only after all references are released.
//
// Note that a Buffer is not safe for concurrent access and instead each
// goroutine should use its own reference to the data, which can be acquired via
// a call to Ref().
//
// Attempts to access the underlying data after releasing the reference to the
// Buffer will panic.
type Buffer interface {
// ReadOnlyData returns the underlying byte slice. Note that it is undefined
// behavior to modify the contents of this slice in any way.
ReadOnlyData() []byte
// Ref increases the reference counter for this Buffer.
Ref()
// Free decrements this Buffer's reference counter and frees the underlying
// byte slice if the counter reaches 0 as a result of this call.
Free()
// Len returns the Buffer's size.
Len() int
split(n int) (left, right Buffer)
read(buf []byte) (int, Buffer)
}
var (
bufferPoolingThreshold = 1 << 10
bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }}
refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }}
)
func IsBelowBufferPoolingThreshold(size int) bool {
return size <= bufferPoolingThreshold
}
type buffer struct {
origData *[]byte
data []byte
refs *atomic.Int32
pool BufferPool
}
func newBuffer() *buffer {
return bufferObjectPool.Get().(*buffer)
}
// NewBuffer creates a new Buffer from the given data, initializing the reference
// counter to 1. The data will then be returned to the given pool when all
// references to the returned Buffer are released. As a special case to avoid
// additional allocations, if the given buffer pool is nil, the returned buffer
// will be a "no-op" Buffer where invoking Buffer.Free() does nothing and the
// underlying data is never freed.
//
// Note that the backing array of the given data is not copied.
func NewBuffer(data *[]byte, pool BufferPool) Buffer {
if pool == nil || IsBelowBufferPoolingThreshold(len(*data)) {
return (SliceBuffer)(*data)
}
b := newBuffer()
b.origData = data
b.data = *data
b.pool = pool
b.refs = refObjectPool.Get().(*atomic.Int32)
b.refs.Add(1)
return b
}
// Copy creates a new Buffer from the given data, initializing the reference
// counter to 1.
//
// It acquires a []byte from the given pool and copies over the backing array
// of the given data. The []byte acquired from the pool is returned to the
// pool when all references to the returned Buffer are released.
func Copy(data []byte, pool BufferPool) Buffer {
if IsBelowBufferPoolingThreshold(len(data)) {
buf := make(SliceBuffer, len(data))
copy(buf, data)
return buf
}
buf := pool.Get(len(data))
copy(*buf, data)
return NewBuffer(buf, pool)
}
func (b *buffer) ReadOnlyData() []byte {
if b.refs == nil {
panic("Cannot read freed buffer")
}
return b.data
}
func (b *buffer) Ref() {
if b.refs == nil {
panic("Cannot ref freed buffer")
}
b.refs.Add(1)
}
func (b *buffer) Free() {
if b.refs == nil {
panic("Cannot free freed buffer")
}
refs := b.refs.Add(-1)
switch {
case refs > 0:
return
case refs == 0:
if b.pool != nil {
b.pool.Put(b.origData)
}
refObjectPool.Put(b.refs)
b.origData = nil
b.data = nil
b.refs = nil
b.pool = nil
bufferObjectPool.Put(b)
default:
panic("Cannot free freed buffer")
}
}
func (b *buffer) Len() int {
return len(b.ReadOnlyData())
}
func (b *buffer) split(n int) (Buffer, Buffer) {
if b.refs == nil {
panic("Cannot split freed buffer")
}
b.refs.Add(1)
split := newBuffer()
split.origData = b.origData
split.data = b.data[n:]
split.refs = b.refs
split.pool = b.pool
b.data = b.data[:n]
return b, split
}
func (b *buffer) read(buf []byte) (int, Buffer) {
if b.refs == nil {
panic("Cannot read freed buffer")
}
n := copy(buf, b.data)
if n == len(b.data) {
b.Free()
return n, nil
}
b.data = b.data[n:]
return n, b
}
// String returns a string representation of the buffer. May be used for
// debugging purposes.
func (b *buffer) String() string {
return fmt.Sprintf("mem.Buffer(%p, data: %p, length: %d)", b, b.ReadOnlyData(), len(b.ReadOnlyData()))
}
func ReadUnsafe(dst []byte, buf Buffer) (int, Buffer) {
return buf.read(dst)
}
// SplitUnsafe modifies the receiver to point to the first n bytes while it
// returns a new reference to the remaining bytes. The returned Buffer functions
// just like a normal reference acquired using Ref().
func SplitUnsafe(buf Buffer, n int) (left, right Buffer) {
return buf.split(n)
}
type emptyBuffer struct{}
func (e emptyBuffer) ReadOnlyData() []byte {
return nil
}
func (e emptyBuffer) Ref() {}
func (e emptyBuffer) Free() {}
func (e emptyBuffer) Len() int {
return 0
}
func (e emptyBuffer) split(n int) (left, right Buffer) {
return e, e
}
func (e emptyBuffer) read(buf []byte) (int, Buffer) {
return 0, e
}
type SliceBuffer []byte
func (s SliceBuffer) ReadOnlyData() []byte { return s }
func (s SliceBuffer) Ref() {}
func (s SliceBuffer) Free() {}
func (s SliceBuffer) Len() int { return len(s) }
func (s SliceBuffer) split(n int) (left, right Buffer) {
return s[:n], s[n:]
}
func (s SliceBuffer) read(buf []byte) (int, Buffer) {
n := copy(buf, s)
if n == len(s) {
return n, nil
}
return n, s[n:]
}

View File

@ -213,11 +213,6 @@ func FromIncomingContext(ctx context.Context) (MD, bool) {
// ValueFromIncomingContext returns the metadata value corresponding to the metadata // ValueFromIncomingContext returns the metadata value corresponding to the metadata
// key from the incoming metadata if it exists. Keys are matched in a case insensitive // key from the incoming metadata if it exists. Keys are matched in a case insensitive
// manner. // manner.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ValueFromIncomingContext(ctx context.Context, key string) []string { func ValueFromIncomingContext(ctx context.Context, key string) []string {
md, ok := ctx.Value(mdIncomingKey{}).(MD) md, ok := ctx.Value(mdIncomingKey{}).(MD)
if !ok { if !ok {
@ -228,7 +223,7 @@ func ValueFromIncomingContext(ctx context.Context, key string) []string {
return copyOf(v) return copyOf(v)
} }
for k, v := range md { for k, v := range md {
// Case insenitive comparison: MD is a map, and there's no guarantee // Case insensitive comparison: MD is a map, and there's no guarantee
// that the MD attached to the context is created using our helper // that the MD attached to the context is created using our helper
// functions. // functions.
if strings.EqualFold(k, key) { if strings.EqualFold(k, key) {

View File

@ -22,7 +22,9 @@ package peer
import ( import (
"context" "context"
"fmt"
"net" "net"
"strings"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
@ -39,6 +41,34 @@ type Peer struct {
AuthInfo credentials.AuthInfo AuthInfo credentials.AuthInfo
} }
// String ensures the Peer types implements the Stringer interface in order to
// allow to print a context with a peerKey value effectively.
func (p *Peer) String() string {
if p == nil {
return "Peer<nil>"
}
sb := &strings.Builder{}
sb.WriteString("Peer{")
if p.Addr != nil {
fmt.Fprintf(sb, "Addr: '%s', ", p.Addr.String())
} else {
fmt.Fprintf(sb, "Addr: <nil>, ")
}
if p.LocalAddr != nil {
fmt.Fprintf(sb, "LocalAddr: '%s', ", p.LocalAddr.String())
} else {
fmt.Fprintf(sb, "LocalAddr: <nil>, ")
}
if p.AuthInfo != nil {
fmt.Fprintf(sb, "AuthInfo: '%s'", p.AuthInfo.AuthType())
} else {
fmt.Fprintf(sb, "AuthInfo: <nil>")
}
sb.WriteString("}")
return sb.String()
}
type peerKey struct{} type peerKey struct{}
// NewContext creates a new context with peer information attached. // NewContext creates a new context with peer information attached.

View File

@ -20,8 +20,9 @@ package grpc
import ( import (
"context" "context"
"fmt"
"io" "io"
"sync" "sync/atomic"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -32,35 +33,43 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
// pickerGeneration stores a picker and a channel used to signal that a picker
// newer than this one is available.
type pickerGeneration struct {
// picker is the picker produced by the LB policy. May be nil if a picker
// has never been produced.
picker balancer.Picker
// blockingCh is closed when the picker has been invalidated because there
// is a new one available.
blockingCh chan struct{}
}
// pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick // pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick
// actions and unblock when there's a picker update. // actions and unblock when there's a picker update.
type pickerWrapper struct { type pickerWrapper struct {
mu sync.Mutex // If pickerGen holds a nil pointer, the pickerWrapper is closed.
done bool pickerGen atomic.Pointer[pickerGeneration]
blockingCh chan struct{}
picker balancer.Picker
statsHandlers []stats.Handler // to record blocking picker calls statsHandlers []stats.Handler // to record blocking picker calls
} }
func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper { func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper {
return &pickerWrapper{ pw := &pickerWrapper{
blockingCh: make(chan struct{}),
statsHandlers: statsHandlers, statsHandlers: statsHandlers,
} }
pw.pickerGen.Store(&pickerGeneration{
blockingCh: make(chan struct{}),
})
return pw
} }
// updatePicker is called by UpdateBalancerState. It unblocks all blocked pick. // updatePicker is called by UpdateState calls from the LB policy. It
// unblocks all blocked pick.
func (pw *pickerWrapper) updatePicker(p balancer.Picker) { func (pw *pickerWrapper) updatePicker(p balancer.Picker) {
pw.mu.Lock() old := pw.pickerGen.Swap(&pickerGeneration{
if pw.done { picker: p,
pw.mu.Unlock() blockingCh: make(chan struct{}),
return })
} close(old.blockingCh)
pw.picker = p
// pw.blockingCh should never be nil.
close(pw.blockingCh)
pw.blockingCh = make(chan struct{})
pw.mu.Unlock()
} }
// doneChannelzWrapper performs the following: // doneChannelzWrapper performs the following:
@ -97,27 +106,24 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
var lastPickErr error var lastPickErr error
for { for {
pw.mu.Lock() pg := pw.pickerGen.Load()
if pw.done { if pg == nil {
pw.mu.Unlock()
return nil, balancer.PickResult{}, ErrClientConnClosing return nil, balancer.PickResult{}, ErrClientConnClosing
} }
if pg.picker == nil {
if pw.picker == nil { ch = pg.blockingCh
ch = pw.blockingCh
} }
if ch == pw.blockingCh { if ch == pg.blockingCh {
// This could happen when either: // This could happen when either:
// - pw.picker is nil (the previous if condition), or // - pw.picker is nil (the previous if condition), or
// - has called pick on the current picker. // - we have already called pick on the current picker.
pw.mu.Unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
var errStr string var errStr string
if lastPickErr != nil { if lastPickErr != nil {
errStr = "latest balancer error: " + lastPickErr.Error() errStr = "latest balancer error: " + lastPickErr.Error()
} else { } else {
errStr = ctx.Err().Error() errStr = fmt.Sprintf("received context error while waiting for new LB policy update: %s", ctx.Err().Error())
} }
switch ctx.Err() { switch ctx.Err() {
case context.DeadlineExceeded: case context.DeadlineExceeded:
@ -144,9 +150,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
} }
} }
ch = pw.blockingCh ch = pg.blockingCh
p := pw.picker p := pg.picker
pw.mu.Unlock()
pickResult, err := p.Pick(info) pickResult, err := p.Pick(info)
if err != nil { if err != nil {
@ -196,24 +201,15 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
} }
func (pw *pickerWrapper) close() { func (pw *pickerWrapper) close() {
pw.mu.Lock() old := pw.pickerGen.Swap(nil)
defer pw.mu.Unlock() close(old.blockingCh)
if pw.done {
return
}
pw.done = true
close(pw.blockingCh)
} }
// reset clears the pickerWrapper and prepares it for being used again when idle // reset clears the pickerWrapper and prepares it for being used again when idle
// mode is exited. // mode is exited.
func (pw *pickerWrapper) reset() { func (pw *pickerWrapper) reset() {
pw.mu.Lock() old := pw.pickerGen.Swap(&pickerGeneration{blockingCh: make(chan struct{})})
defer pw.mu.Unlock() close(old.blockingCh)
if pw.done {
return
}
pw.blockingCh = make(chan struct{})
} }
// dropError is a wrapper error that indicates the LB policy wishes to drop the // dropError is a wrapper error that indicates the LB policy wishes to drop the

View File

@ -20,6 +20,7 @@ package grpc
import ( import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -31,9 +32,10 @@ import (
// later release. // later release.
type PreparedMsg struct { type PreparedMsg struct {
// Struct for preparing msg before sending them // Struct for preparing msg before sending them
encodedData []byte encodedData mem.BufferSlice
hdr []byte hdr []byte
payload []byte payload mem.BufferSlice
pf payloadFormat
} }
// Encode marshalls and compresses the message using the codec and compressor for the stream. // Encode marshalls and compresses the message using the codec and compressor for the stream.
@ -57,11 +59,27 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error {
if err != nil { if err != nil {
return err return err
} }
p.encodedData = data
compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp) materializedData := data.Materialize()
data.Free()
p.encodedData = mem.BufferSlice{mem.NewBuffer(&materializedData, nil)}
// TODO: it should be possible to grab the bufferPool from the underlying
// stream implementation with a type cast to its actual type (such as
// addrConnStream) and accessing the buffer pool directly.
var compData mem.BufferSlice
compData, p.pf, err = compress(p.encodedData, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp, mem.DefaultBufferPool())
if err != nil { if err != nil {
return err return err
} }
p.hdr, p.payload = msgHeader(data, compData)
if p.pf.isCompressed() {
materializedCompData := compData.Materialize()
compData.Free()
compData = mem.BufferSlice{mem.NewBuffer(&materializedCompData, nil)}
}
p.hdr, p.payload = msgHeader(p.encodedData, compData, p.pf)
return nil return nil
} }

View File

@ -1,123 +0,0 @@
#!/bin/bash
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eu -o pipefail
WORKDIR=$(mktemp -d)
function finish {
rm -rf "$WORKDIR"
}
trap finish EXIT
export GOBIN=${WORKDIR}/bin
export PATH=${GOBIN}:${PATH}
mkdir -p ${GOBIN}
echo "remove existing generated files"
# grpc_testing_not_regenerate/*.pb.go is not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm -f $(find . -name '*.pb.go' | grep -v 'grpc_testing_not_regenerate')
echo "go install google.golang.org/protobuf/cmd/protoc-gen-go"
(cd test/tools && go install google.golang.org/protobuf/cmd/protoc-gen-go)
echo "go install cmd/protoc-gen-go-grpc"
(cd cmd/protoc-gen-go-grpc && go install .)
echo "git clone https://github.com/grpc/grpc-proto"
git clone --quiet https://github.com/grpc/grpc-proto ${WORKDIR}/grpc-proto
echo "git clone https://github.com/protocolbuffers/protobuf"
git clone --quiet https://github.com/protocolbuffers/protobuf ${WORKDIR}/protobuf
# Pull in code.proto as a proto dependency
mkdir -p ${WORKDIR}/googleapis/google/rpc
echo "curl https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto"
curl --silent https://raw.githubusercontent.com/googleapis/googleapis/master/google/rpc/code.proto > ${WORKDIR}/googleapis/google/rpc/code.proto
mkdir -p ${WORKDIR}/out
# Generates sources without the embed requirement
LEGACY_SOURCES=(
${WORKDIR}/grpc-proto/grpc/binlog/v1/binarylog.proto
${WORKDIR}/grpc-proto/grpc/channelz/v1/channelz.proto
${WORKDIR}/grpc-proto/grpc/health/v1/health.proto
${WORKDIR}/grpc-proto/grpc/lb/v1/load_balancer.proto
profiling/proto/service.proto
${WORKDIR}/grpc-proto/grpc/reflection/v1alpha/reflection.proto
${WORKDIR}/grpc-proto/grpc/reflection/v1/reflection.proto
)
# Generates only the new gRPC Service symbols
SOURCES=(
$(git ls-files --exclude-standard --cached --others "*.proto" | grep -v '^\(profiling/proto/service.proto\|reflection/grpc_reflection_v1alpha/reflection.proto\)$')
${WORKDIR}/grpc-proto/grpc/gcp/altscontext.proto
${WORKDIR}/grpc-proto/grpc/gcp/handshaker.proto
${WORKDIR}/grpc-proto/grpc/gcp/transport_security_common.proto
${WORKDIR}/grpc-proto/grpc/lookup/v1/rls.proto
${WORKDIR}/grpc-proto/grpc/lookup/v1/rls_config.proto
${WORKDIR}/grpc-proto/grpc/testing/*.proto
${WORKDIR}/grpc-proto/grpc/core/*.proto
)
# These options of the form 'Mfoo.proto=bar' instruct the codegen to use an
# import path of 'bar' in the generated code when 'foo.proto' is imported in
# one of the sources.
#
# Note that the protos listed here are all for testing purposes. All protos to
# be used externally should have a go_package option (and they don't need to be
# listed here).
OPTS=Mgrpc/core/stats.proto=google.golang.org/grpc/interop/grpc_testing/core,\
Mgrpc/testing/benchmark_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/stats.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/report_qps_scenario_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/messages.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/worker_service.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/control.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/test.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/payloads.proto=google.golang.org/grpc/interop/grpc_testing,\
Mgrpc/testing/empty.proto=google.golang.org/grpc/interop/grpc_testing
for src in ${SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS}:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
-I${WORKDIR}/protobuf/src \
${src}
done
for src in ${LEGACY_SOURCES[@]}; do
echo "protoc ${src}"
protoc --go_out=${OPTS}:${WORKDIR}/out --go-grpc_out=${OPTS},require_unimplemented_servers=false:${WORKDIR}/out \
-I"." \
-I${WORKDIR}/grpc-proto \
-I${WORKDIR}/googleapis \
-I${WORKDIR}/protobuf/src \
${src}
done
# The go_package option in grpc/lookup/v1/rls.proto doesn't match the
# current location. Move it into the right place.
mkdir -p ${WORKDIR}/out/google.golang.org/grpc/internal/proto/grpc_lookup_v1
mv ${WORKDIR}/out/google.golang.org/grpc/lookup/grpc_lookup_v1/* ${WORKDIR}/out/google.golang.org/grpc/internal/proto/grpc_lookup_v1
# grpc_testing_not_regenerate/*.pb.go are not re-generated,
# see grpc_testing_not_regenerate/README.md for details.
rm ${WORKDIR}/out/google.golang.org/grpc/reflection/grpc_testing_not_regenerate/*.pb.go
cp -R ${WORKDIR}/out/google.golang.org/grpc/* .

View File

@ -18,19 +18,43 @@
// Package dns implements a dns resolver to be installed as the default resolver // Package dns implements a dns resolver to be installed as the default resolver
// in grpc. // in grpc.
//
// Deprecated: this package is imported by grpc and should not need to be
// imported directly by users.
package dns package dns
import ( import (
"time"
"google.golang.org/grpc/internal/resolver/dns" "google.golang.org/grpc/internal/resolver/dns"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
// SetResolvingTimeout sets the maximum duration for DNS resolution requests.
//
// This function affects the global timeout used by all channels using the DNS
// name resolver scheme.
//
// It must be called only at application startup, before any gRPC calls are
// made. Modifying this value after initialization is not thread-safe.
//
// The default value is 30 seconds. Setting the timeout too low may result in
// premature timeouts during resolution, while setting it too high may lead to
// unnecessary delays in service discovery. Choose a value appropriate for your
// specific needs and network environment.
func SetResolvingTimeout(timeout time.Duration) {
dns.ResolvingTimeout = timeout
}
// NewBuilder creates a dnsBuilder which is used to factory DNS resolvers. // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
// //
// Deprecated: import grpc and use resolver.Get("dns") instead. // Deprecated: import grpc and use resolver.Get("dns") instead.
func NewBuilder() resolver.Builder { func NewBuilder() resolver.Builder {
return dns.NewBuilder() return dns.NewBuilder()
} }
// SetMinResolutionInterval sets the default minimum interval at which DNS
// re-resolutions are allowed. This helps to prevent excessive re-resolution.
//
// It must be called only at application startup, before any gRPC calls are
// made. Modifying this value after initialization is not thread-safe.
func SetMinResolutionInterval(d time.Duration) {
dns.MinResolutionInterval = d
}

View File

@ -29,6 +29,7 @@ import (
"google.golang.org/grpc/attributes" "google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
@ -63,16 +64,18 @@ func Get(scheme string) Builder {
} }
// SetDefaultScheme sets the default scheme that will be used. The default // SetDefaultScheme sets the default scheme that will be used. The default
// default scheme is "passthrough". // scheme is initially set to "passthrough".
// //
// NOTE: this function must only be called during initialization time (i.e. in // NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. The scheme set last overrides // an init() function), and is not thread-safe. The scheme set last overrides
// previously set values. // previously set values.
func SetDefaultScheme(scheme string) { func SetDefaultScheme(scheme string) {
defaultScheme = scheme defaultScheme = scheme
internal.UserSetDefaultScheme = true
} }
// GetDefaultScheme gets the default scheme that will be used. // GetDefaultScheme gets the default scheme that will be used by grpc.Dial. If
// SetDefaultScheme is never called, the default scheme used by grpc.NewClient is "dns" instead.
func GetDefaultScheme() string { func GetDefaultScheme() string {
return defaultScheme return defaultScheme
} }
@ -168,6 +171,9 @@ type BuildOptions struct {
// field. In most cases though, it is not appropriate, and this field may // field. In most cases though, it is not appropriate, and this field may
// be ignored. // be ignored.
Dialer func(context.Context, string) (net.Conn, error) Dialer func(context.Context, string) (net.Conn, error)
// Authority is the effective authority of the clientconn for which the
// resolver is built.
Authority string
} }
// An Endpoint is one network endpoint, or server, which may have multiple // An Endpoint is one network endpoint, or server, which may have multiple
@ -281,9 +287,9 @@ func (t Target) Endpoint() string {
return strings.TrimPrefix(endpoint, "/") return strings.TrimPrefix(endpoint, "/")
} }
// String returns a string representation of Target. // String returns the canonical string representation of Target.
func (t Target) String() string { func (t Target) String() string {
return t.URL.String() return t.URL.Scheme + "://" + t.URL.Host + "/" + t.Endpoint()
} }
// Builder creates a resolver that will be used to watch name resolution updates. // Builder creates a resolver that will be used to watch name resolution updates.

View File

@ -66,7 +66,7 @@ func newCCResolverWrapper(cc *ClientConn) *ccResolverWrapper {
// any newly created ccResolverWrapper, except that close may be called instead. // any newly created ccResolverWrapper, except that close may be called instead.
func (ccr *ccResolverWrapper) start() error { func (ccr *ccResolverWrapper) start() error {
errCh := make(chan error) errCh := make(chan error)
ccr.serializer.Schedule(func(ctx context.Context) { ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
@ -75,6 +75,7 @@ func (ccr *ccResolverWrapper) start() error {
DialCreds: ccr.cc.dopts.copts.TransportCredentials, DialCreds: ccr.cc.dopts.copts.TransportCredentials,
CredsBundle: ccr.cc.dopts.copts.CredsBundle, CredsBundle: ccr.cc.dopts.copts.CredsBundle,
Dialer: ccr.cc.dopts.copts.Dialer, Dialer: ccr.cc.dopts.copts.Dialer,
Authority: ccr.cc.authority,
} }
var err error var err error
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts) ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
@ -84,7 +85,7 @@ func (ccr *ccResolverWrapper) start() error {
} }
func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) { func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
ccr.serializer.Schedule(func(ctx context.Context) { ccr.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || ccr.resolver == nil { if ctx.Err() != nil || ccr.resolver == nil {
return return
} }
@ -96,12 +97,12 @@ func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
// finished shutting down, the channel should block on ccr.serializer.Done() // finished shutting down, the channel should block on ccr.serializer.Done()
// without cc.mu held. // without cc.mu held.
func (ccr *ccResolverWrapper) close() { func (ccr *ccResolverWrapper) close() {
channelz.Info(logger, ccr.cc.channelzID, "Closing the name resolver") channelz.Info(logger, ccr.cc.channelz, "Closing the name resolver")
ccr.mu.Lock() ccr.mu.Lock()
ccr.closed = true ccr.closed = true
ccr.mu.Unlock() ccr.mu.Unlock()
ccr.serializer.Schedule(func(context.Context) { ccr.serializer.TrySchedule(func(context.Context) {
if ccr.resolver == nil { if ccr.resolver == nil {
return return
} }
@ -146,7 +147,7 @@ func (ccr *ccResolverWrapper) ReportError(err error) {
return return
} }
ccr.mu.Unlock() ccr.mu.Unlock()
channelz.Warningf(logger, ccr.cc.channelzID, "ccResolverWrapper: reporting error to cc: %v", err) channelz.Warningf(logger, ccr.cc.channelz, "ccResolverWrapper: reporting error to cc: %v", err)
ccr.cc.updateResolverStateAndUnlock(resolver.State{}, err) ccr.cc.updateResolverStateAndUnlock(resolver.State{}, err)
} }
@ -170,12 +171,15 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
// ParseServiceConfig is called by resolver implementations to parse a JSON // ParseServiceConfig is called by resolver implementations to parse a JSON
// representation of the service config. // representation of the service config.
func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult { func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.ParseResult {
return parseServiceConfig(scJSON) return parseServiceConfig(scJSON, ccr.cc.dopts.maxCallAttempts)
} }
// addChannelzTraceEvent adds a channelz trace event containing the new // addChannelzTraceEvent adds a channelz trace event containing the new
// state received from resolver implementations. // state received from resolver implementations.
func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) { func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
if !logger.V(0) && !channelz.IsOn() {
return
}
var updates []string var updates []string
var oldSC, newSC *ServiceConfig var oldSC, newSC *ServiceConfig
var oldOK, newOK bool var oldOK, newOK bool
@ -193,5 +197,5 @@ func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
} else if len(ccr.curState.Addresses) == 0 && len(s.Addresses) > 0 { } else if len(ccr.curState.Addresses) == 0 && len(s.Addresses) > 0 {
updates = append(updates, "resolver returned new addresses") updates = append(updates, "resolver returned new addresses")
} }
channelz.Infof(logger, ccr.cc.channelzID, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; ")) channelz.Infof(logger, ccr.cc.channelz, "Resolver state updated: %s (%v)", pretty.ToJSON(s), strings.Join(updates, "; "))
} }

View File

@ -19,7 +19,6 @@
package grpc package grpc
import ( import (
"bytes"
"compress/gzip" "compress/gzip"
"context" "context"
"encoding/binary" "encoding/binary"
@ -35,6 +34,7 @@ import (
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -271,17 +271,13 @@ func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
} }
} }
// WaitForReady configures the action to take when an RPC is attempted on broken // WaitForReady configures the RPC's behavior when the client is in
// connections or unreachable servers. If waitForReady is false and the // TRANSIENT_FAILURE, which occurs when all addresses fail to connect. If
// connection is in the TRANSIENT_FAILURE state, the RPC will fail // waitForReady is false, the RPC will fail immediately. Otherwise, the client
// immediately. Otherwise, the RPC client will block the call until a // will wait until a connection becomes available or the RPC's deadline is
// connection is available (or the call is canceled or times out) and will // reached.
// retry the call if it fails due to a transient error. gRPC will not retry if
// data was written to the wire unless the server indicates it did not process
// the data. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/wait-for-ready.md.
// //
// By default, RPCs don't "wait for ready". // By default, RPCs do not "wait for ready".
func WaitForReady(waitForReady bool) CallOption { func WaitForReady(waitForReady bool) CallOption {
return FailFastCallOption{FailFast: !waitForReady} return FailFastCallOption{FailFast: !waitForReady}
} }
@ -515,11 +511,51 @@ type ForceCodecCallOption struct {
} }
func (o ForceCodecCallOption) before(c *callInfo) error { func (o ForceCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec c.codec = newCodecV1Bridge(o.Codec)
return nil return nil
} }
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {} func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
// ForceCodecV2 returns a CallOption that will set codec to be used for all
// request and response messages for a call. The result of calling Name() will
// be used as the content-subtype after converting to lowercase, unless
// CallContentSubtype is also used.
//
// See Content-Type on
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details. Also see the documentation on RegisterCodec and
// CallContentSubtype for more details on the interaction between Codec and
// content-subtype.
//
// This function is provided for advanced users; prefer to use only
// CallContentSubtype to select a registered codec instead.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceCodecV2(codec encoding.CodecV2) CallOption {
return ForceCodecV2CallOption{CodecV2: codec}
}
// ForceCodecV2CallOption is a CallOption that indicates the codec used for
// marshaling messages.
//
// # Experimental
//
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type ForceCodecV2CallOption struct {
CodecV2 encoding.CodecV2
}
func (o ForceCodecV2CallOption) before(c *callInfo) error {
c.codec = o.CodecV2
return nil
}
func (o ForceCodecV2CallOption) after(c *callInfo, attempt *csAttempt) {}
// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of // CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
// an encoding.Codec. // an encoding.Codec.
// //
@ -540,7 +576,7 @@ type CustomCodecCallOption struct {
} }
func (o CustomCodecCallOption) before(c *callInfo) error { func (o CustomCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec c.codec = newCodecV0Bridge(o.Codec)
return nil return nil
} }
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {} func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}
@ -581,19 +617,28 @@ const (
compressionMade payloadFormat = 1 // compressed compressionMade payloadFormat = 1 // compressed
) )
func (pf payloadFormat) isCompressed() bool {
return pf == compressionMade
}
type streamReader interface {
ReadHeader(header []byte) error
Read(n int) (mem.BufferSlice, error)
}
// parser reads complete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
type parser struct { type parser struct {
// r is the underlying reader. // r is the underlying reader.
// See the comment on recvMsg for the permissible // See the comment on recvMsg for the permissible
// error types. // error types.
r io.Reader r streamReader
// The header of a gRPC message. Find more detail at // The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte header [5]byte
// recvBufferPool is the pool of shared receive buffers. // bufferPool is the pool of shared receive buffers.
recvBufferPool SharedBufferPool bufferPool mem.BufferPool
} }
// recvMsg reads a complete gRPC message from the stream. // recvMsg reads a complete gRPC message from the stream.
@ -608,14 +653,15 @@ type parser struct {
// - an error from the status package // - an error from the status package
// //
// No other error values or types must be returned, which also means // No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible // that the underlying streamReader must not return an incompatible
// error. // error.
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) { func (p *parser) recvMsg(maxReceiveMessageSize int) (payloadFormat, mem.BufferSlice, error) {
if _, err := p.r.Read(p.header[:]); err != nil { err := p.r.ReadHeader(p.header[:])
if err != nil {
return 0, nil, err return 0, nil, err
} }
pf = payloadFormat(p.header[0]) pf := payloadFormat(p.header[0])
length := binary.BigEndian.Uint32(p.header[1:]) length := binary.BigEndian.Uint32(p.header[1:])
if length == 0 { if length == 0 {
@ -627,20 +673,21 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
if int(length) > maxReceiveMessageSize { if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
} }
msg = p.recvBufferPool.Get(int(length))
if _, err := p.r.Read(msg); err != nil { data, err := p.r.Read(int(length))
if err != nil {
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
return 0, nil, err return 0, nil, err
} }
return pf, msg, nil return pf, data, nil
} }
// encode serializes msg and returns a buffer containing the message, or an // encode serializes msg and returns a buffer containing the message, or an
// error if it is too large to be transmitted by grpc. If msg is nil, it // error if it is too large to be transmitted by grpc. If msg is nil, it
// generates an empty message. // generates an empty message.
func encode(c baseCodec, msg any) ([]byte, error) { func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
if msg == nil { // NOTE: typed nils will not be caught by this check if msg == nil { // NOTE: typed nils will not be caught by this check
return nil, nil return nil, nil
} }
@ -648,7 +695,8 @@ func encode(c baseCodec, msg any) ([]byte, error) {
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
} }
if uint(len(b)) > math.MaxUint32 { if uint(b.Len()) > math.MaxUint32 {
b.Free()
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
} }
return b, nil return b, nil
@ -659,34 +707,41 @@ func encode(c baseCodec, msg any) ([]byte, error) {
// indicating no compression was done. // indicating no compression was done.
// //
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor. // TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) { func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
if compressor == nil && cp == nil { if (compressor == nil && cp == nil) || in.Len() == 0 {
return nil, nil return nil, compressionNone, nil
}
if len(in) == 0 {
return nil, nil
} }
var out mem.BufferSlice
w := mem.NewWriter(&out, pool)
wrapErr := func(err error) error { wrapErr := func(err error) error {
out.Free()
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
} }
cbuf := &bytes.Buffer{}
if compressor != nil { if compressor != nil {
z, err := compressor.Compress(cbuf) z, err := compressor.Compress(w)
if err != nil { if err != nil {
return nil, wrapErr(err) return nil, 0, wrapErr(err)
}
for _, b := range in {
if _, err := z.Write(b.ReadOnlyData()); err != nil {
return nil, 0, wrapErr(err)
} }
if _, err := z.Write(in); err != nil {
return nil, wrapErr(err)
} }
if err := z.Close(); err != nil { if err := z.Close(); err != nil {
return nil, wrapErr(err) return nil, 0, wrapErr(err)
} }
} else { } else {
if err := cp.Do(cbuf, in); err != nil { // This is obviously really inefficient since it fully materializes the data, but
return nil, wrapErr(err) // there is no way around this with the old Compressor API. At least it attempts
// to return the buffer to the provider, in the hopes it can be reused (maybe
// even by a subsequent call to this very function).
buf := in.MaterializeToBuffer(pool)
defer buf.Free()
if err := cp.Do(w, buf.ReadOnlyData()); err != nil {
return nil, 0, wrapErr(err)
} }
} }
return cbuf.Bytes(), nil return out, compressionMade, nil
} }
const ( const (
@ -697,33 +752,36 @@ const (
// msgHeader returns a 5-byte header for the message being transmitted and the // msgHeader returns a 5-byte header for the message being transmitted and the
// payload, which is compData if non-nil or data otherwise. // payload, which is compData if non-nil or data otherwise.
func msgHeader(data, compData []byte) (hdr []byte, payload []byte) { func msgHeader(data, compData mem.BufferSlice, pf payloadFormat) (hdr []byte, payload mem.BufferSlice) {
hdr = make([]byte, headerLen) hdr = make([]byte, headerLen)
if compData != nil { hdr[0] = byte(pf)
hdr[0] = byte(compressionMade)
data = compData var length uint32
if pf.isCompressed() {
length = uint32(compData.Len())
payload = compData
} else { } else {
hdr[0] = byte(compressionNone) length = uint32(data.Len())
payload = data
} }
// Write length of payload into buf // Write length of payload into buf
binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data))) binary.BigEndian.PutUint32(hdr[payloadLen:], length)
return hdr, data return hdr, payload
} }
func outPayload(client bool, msg any, data, payload []byte, t time.Time) *stats.OutPayload { func outPayload(client bool, msg any, dataLength, payloadLength int, t time.Time) *stats.OutPayload {
return &stats.OutPayload{ return &stats.OutPayload{
Client: client, Client: client,
Payload: msg, Payload: msg,
Data: data, Length: dataLength,
Length: len(data), WireLength: payloadLength + headerLen,
WireLength: len(payload) + headerLen, CompressedLength: payloadLength,
CompressedLength: len(payload),
SentTime: t, SentTime: t,
} }
} }
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool, isServer bool) *status.Status {
switch pf { switch pf {
case compressionNone: case compressionNone:
case compressionMade: case compressionMade:
@ -731,7 +789,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
} }
if !haveCompressor { if !haveCompressor {
if isServer {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} else {
return status.Newf(codes.Internal, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
} }
default: default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
@ -741,88 +803,129 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
type payloadInfo struct { type payloadInfo struct {
compressedLength int // The compressed length got from wire. compressedLength int // The compressed length got from wire.
uncompressedBytes []byte uncompressedBytes mem.BufferSlice
} }
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) { func (p *payloadInfo) free() {
pf, buf, err := p.recvMsg(maxReceiveMessageSize) if p != nil && p.uncompressedBytes != nil {
p.uncompressedBytes.Free()
}
}
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if payInfo != nil {
payInfo.compressedLength = len(buf)
}
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { compressedLength := compressed.Len()
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil, isServer); st != nil {
compressed.Free()
return nil, st.Err() return nil, st.Err()
} }
var size int var size int
if pf == compressionMade { if pf.isCompressed() {
defer compressed.Free()
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default. // use this decompressor as the default.
if dc != nil { if dc != nil {
buf, err = dc.Do(bytes.NewReader(buf)) var uncompressedBuf []byte
size = len(buf) uncompressedBuf, err = dc.Do(compressed.Reader())
if err == nil {
out = mem.BufferSlice{mem.NewBuffer(&uncompressedBuf, nil)}
}
size = len(uncompressedBuf)
} else { } else {
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize) out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
} }
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
} }
if size > maxReceiveMessageSize { if size > maxReceiveMessageSize {
out.Free()
// TODO: Revisit the error code. Currently keep it consistent with java // TODO: Revisit the error code. Currently keep it consistent with java
// implementation. // implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
} }
} else {
out = compressed
} }
return buf, nil
if payInfo != nil {
payInfo.compressedLength = compressedLength
out.Ref()
payInfo.uncompressedBytes = out
}
return out, nil
} }
// Using compressor, decompress d, returning data and size. // Using compressor, decompress d, returning data and size.
// Optionally, if data will be over maxReceiveMessageSize, just return the size. // Optionally, if data will be over maxReceiveMessageSize, just return the size.
func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize int) ([]byte, int, error) { func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
dcReader, err := compressor.Decompress(bytes.NewReader(d)) dcReader, err := compressor.Decompress(d.Reader())
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
if sizer, ok := compressor.(interface {
DecompressedSize(compressedBytes []byte) int // TODO: Can/should this still be preserved with the new BufferSlice API? Are
}); ok { // there any actual benefits to allocating a single large buffer instead of
if size := sizer.DecompressedSize(d); size >= 0 { // multiple smaller ones?
if size > maxReceiveMessageSize { //if sizer, ok := compressor.(interface {
return nil, size, nil // DecompressedSize(compressedBytes []byte) int
//}); ok {
// if size := sizer.DecompressedSize(d); size >= 0 {
// if size > maxReceiveMessageSize {
// return nil, size, nil
// }
// // size is used as an estimate to size the buffer, but we
// // will read more data if available.
// // +MinRead so ReadFrom will not reallocate if size is correct.
// //
// // TODO: If we ensure that the buffer size is the same as the DecompressedSize,
// // we can also utilize the recv buffer pool here.
// buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
// bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
// return buf.Bytes(), int(bytesRead), err
// }
//}
var out mem.BufferSlice
_, err = io.Copy(mem.NewWriter(&out, pool), io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
if err != nil {
out.Free()
return nil, 0, err
} }
// size is used as an estimate to size the buffer, but we return out, out.Len(), nil
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
}
}
// Read from LimitReader with limit max+1. So if the underlying
// reader is over limit, the result will be bigger than max.
d, err = io.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return d, len(d), err
} }
// For the two compressor parameters, both should not be set, but if they are, // For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor. // dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API? // TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error { func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil { if err != nil {
return err return err
} }
if err := c.Unmarshal(buf, m); err != nil {
// If the codec wants its own reference to the data, it can get it. Otherwise, always
// free the buffers.
defer data.Free()
if err := c.Unmarshal(data, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err) return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
} }
if payInfo != nil {
payInfo.uncompressedBytes = buf
} else {
p.recvBufferPool.Put(&buf)
}
return nil return nil
} }
@ -925,7 +1028,7 @@ func setCallInfoCodec(c *callInfo) error {
// encoding.Codec (Name vs. String method name). We only support // encoding.Codec (Name vs. String method name). We only support
// setting content subtype from encoding.Codec to avoid a behavior // setting content subtype from encoding.Codec to avoid a behavior
// change with the deprecated version. // change with the deprecated version.
if ec, ok := c.codec.(encoding.Codec); ok { if ec, ok := c.codec.(encoding.CodecV2); ok {
c.contentSubtype = strings.ToLower(ec.Name()) c.contentSubtype = strings.ToLower(ec.Name())
} }
} }
@ -934,34 +1037,21 @@ func setCallInfoCodec(c *callInfo) error {
if c.contentSubtype == "" { if c.contentSubtype == "" {
// No codec specified in CallOptions; use proto by default. // No codec specified in CallOptions; use proto by default.
c.codec = encoding.GetCodec(proto.Name) c.codec = getCodec(proto.Name)
return nil return nil
} }
// c.contentSubtype is already lowercased in CallContentSubtype // c.contentSubtype is already lowercased in CallContentSubtype
c.codec = encoding.GetCodec(c.contentSubtype) c.codec = getCodec(c.contentSubtype)
if c.codec == nil { if c.codec == nil {
return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype) return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype)
} }
return nil return nil
} }
// channelzData is used to store channelz related data for ClientConn, addrConn and Server.
// These fields cannot be embedded in the original structs (e.g. ClientConn), since to do atomic
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
// Here, by grouping those int64 fields inside a struct, we are enforcing the alignment.
type channelzData struct {
callsStarted int64
callsFailed int64
callsSucceeded int64
// lastCallStartedTime stores the timestamp that last call starts. It is of int64 type instead of
// time.Time since it's more costly to atomically update time.Time variable than int64 variable.
lastCallStartedTime int64
}
// The SupportPackageIsVersion variables are referenced from generated protocol // The SupportPackageIsVersion variables are referenced from generated protocol
// buffer files to ensure compatibility with the gRPC version used. The latest // buffer files to ensure compatibility with the gRPC version used. The latest
// support package version is 7. // support package version is 9.
// //
// Older versions are kept for compatibility. // Older versions are kept for compatibility.
// //
@ -973,6 +1063,7 @@ const (
SupportPackageIsVersion6 = true SupportPackageIsVersion6 = true
SupportPackageIsVersion7 = true SupportPackageIsVersion7 = true
SupportPackageIsVersion8 = true SupportPackageIsVersion8 = true
SupportPackageIsVersion9 = true
) )
const grpcUA = "grpc-go/" + Version const grpcUA = "grpc-go/" + Version

View File

@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -80,7 +81,7 @@ func init() {
} }
internal.BinaryLogger = binaryLogger internal.BinaryLogger = binaryLogger
internal.JoinServerOptions = newJoinServerOption internal.JoinServerOptions = newJoinServerOption
internal.RecvBufferPool = recvBufferPool internal.BufferPool = bufferPool
} }
var statusOK = status.New(codes.OK, "") var statusOK = status.New(codes.OK, "")
@ -137,8 +138,7 @@ type Server struct {
serveWG sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop serveWG sync.WaitGroup // counts active Serve goroutines for Stop/GracefulStop
handlersWG sync.WaitGroup // counts active method handler goroutines handlersWG sync.WaitGroup // counts active method handler goroutines
channelzID *channelz.Identifier channelz *channelz.Server
czData *channelzData
serverWorkerChannel chan func() serverWorkerChannel chan func()
serverWorkerChannelClose func() serverWorkerChannelClose func()
@ -171,7 +171,7 @@ type serverOptions struct {
maxHeaderListSize *uint32 maxHeaderListSize *uint32
headerTableSize *uint32 headerTableSize *uint32
numServerWorkers uint32 numServerWorkers uint32
recvBufferPool SharedBufferPool bufferPool mem.BufferPool
waitForHandlers bool waitForHandlers bool
} }
@ -182,7 +182,7 @@ var defaultServerOptions = serverOptions{
connectionTimeout: 120 * time.Second, connectionTimeout: 120 * time.Second,
writeBufferSize: defaultWriteBufSize, writeBufferSize: defaultWriteBufSize,
readBufferSize: defaultReadBufSize, readBufferSize: defaultReadBufSize,
recvBufferPool: nopBufferPool{}, bufferPool: mem.DefaultBufferPool(),
} }
var globalServerOptions []ServerOption var globalServerOptions []ServerOption
@ -249,11 +249,9 @@ func SharedWriteBuffer(val bool) ServerOption {
} }
// WriteBufferSize determines how much data can be batched before doing a write // WriteBufferSize determines how much data can be batched before doing a write
// on the wire. The corresponding memory allocation for this buffer will be // on the wire. The default value for this buffer is 32KB. Zero or negative
// twice the size to keep syscalls low. The default value for this buffer is // values will disable the write buffer such that each write will be on underlying
// 32KB. Zero or negative values will disable the write buffer such that each // connection. Note: A Send call may not directly translate to a write.
// write will be on underlying connection.
// Note: A Send call may not directly translate to a write.
func WriteBufferSize(s int) ServerOption { func WriteBufferSize(s int) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.writeBufferSize = s o.writeBufferSize = s
@ -316,7 +314,7 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
// Will be supported throughout 1.x. // Will be supported throughout 1.x.
func CustomCodec(codec Codec) ServerOption { func CustomCodec(codec Codec) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.codec = codec o.codec = newCodecV0Bridge(codec)
}) })
} }
@ -345,7 +343,22 @@ func CustomCodec(codec Codec) ServerOption {
// later release. // later release.
func ForceServerCodec(codec encoding.Codec) ServerOption { func ForceServerCodec(codec encoding.Codec) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.codec = codec o.codec = newCodecV1Bridge(codec)
})
}
// ForceServerCodecV2 is the equivalent of ForceServerCodec, but for the new
// CodecV2 interface.
//
// Will be supported throughout 1.x.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func ForceServerCodecV2(codecV2 encoding.CodecV2) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.codec = codecV2
}) })
} }
@ -530,12 +543,22 @@ func ConnectionTimeout(d time.Duration) ServerOption {
}) })
} }
// MaxHeaderListSizeServerOption is a ServerOption that sets the max
// (uncompressed) size of header list that the server is prepared to accept.
type MaxHeaderListSizeServerOption struct {
MaxHeaderListSize uint32
}
func (o MaxHeaderListSizeServerOption) apply(so *serverOptions) {
so.maxHeaderListSize = &o.MaxHeaderListSize
}
// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size // MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size
// of header list that the server is prepared to accept. // of header list that the server is prepared to accept.
func MaxHeaderListSize(s uint32) ServerOption { func MaxHeaderListSize(s uint32) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return MaxHeaderListSizeServerOption{
o.maxHeaderListSize = &s MaxHeaderListSize: s,
}) }
} }
// HeaderTableSize returns a ServerOption that sets the size of dynamic // HeaderTableSize returns a ServerOption that sets the size of dynamic
@ -585,26 +608,9 @@ func WaitForHandlers(w bool) ServerOption {
}) })
} }
// RecvBufferPool returns a ServerOption that configures the server func bufferPool(bufferPool mem.BufferPool) ServerOption {
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
//
// If you are unsure about how to implement a memory pool but want to utilize one,
// begin with grpc.NewSharedBufferPool.
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: StatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// Deprecated: use experimental.WithRecvBufferPool instead. Will be deleted in
// v1.60.0 or later.
func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
return recvBufferPool(bufferPool)
}
func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
return newFuncServerOption(func(o *serverOptions) { return newFuncServerOption(func(o *serverOptions) {
o.recvBufferPool = bufferPool o.bufferPool = bufferPool
}) })
} }
@ -615,7 +621,7 @@ func recvBufferPool(bufferPool SharedBufferPool) ServerOption {
// workload (assuming a QPS of a few thousand requests/sec). // workload (assuming a QPS of a few thousand requests/sec).
const serverWorkerResetThreshold = 1 << 16 const serverWorkerResetThreshold = 1 << 16
// serverWorkers blocks on a *transport.Stream channel forever and waits for // serverWorker blocks on a *transport.Stream channel forever and waits for
// data to be fed by serveStreams. This allows multiple requests to be // data to be fed by serveStreams. This allows multiple requests to be
// processed by the same goroutine, removing the need for expensive stack // processed by the same goroutine, removing the need for expensive stack
// re-allocations (see the runtime.morestack problem [1]). // re-allocations (see the runtime.morestack problem [1]).
@ -661,7 +667,7 @@ func NewServer(opt ...ServerOption) *Server {
services: make(map[string]*serviceInfo), services: make(map[string]*serviceInfo),
quit: grpcsync.NewEvent(), quit: grpcsync.NewEvent(),
done: grpcsync.NewEvent(), done: grpcsync.NewEvent(),
czData: new(channelzData), channelz: channelz.RegisterServer(""),
} }
chainUnaryServerInterceptors(s) chainUnaryServerInterceptors(s)
chainStreamServerInterceptors(s) chainStreamServerInterceptors(s)
@ -675,8 +681,7 @@ func NewServer(opt ...ServerOption) *Server {
s.initServerWorkers() s.initServerWorkers()
} }
s.channelzID = channelz.RegisterServer(&channelzServer{s}, "") channelz.Info(logger, s.channelz, "Server created")
channelz.Info(logger, s.channelzID, "Server created")
return s return s
} }
@ -802,20 +807,13 @@ var ErrServerStopped = errors.New("grpc: the server has been stopped")
type listenSocket struct { type listenSocket struct {
net.Listener net.Listener
channelzID *channelz.Identifier channelz *channelz.Socket
}
func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric {
return &channelz.SocketInternalMetric{
SocketOptions: channelz.GetSocketOption(l.Listener),
LocalAddr: l.Listener.Addr(),
}
} }
func (l *listenSocket) Close() error { func (l *listenSocket) Close() error {
err := l.Listener.Close() err := l.Listener.Close()
channelz.RemoveEntry(l.channelzID) channelz.RemoveEntry(l.channelz.ID)
channelz.Info(logger, l.channelzID, "ListenSocket deleted") channelz.Info(logger, l.channelz, "ListenSocket deleted")
return err return err
} }
@ -857,7 +855,16 @@ func (s *Server) Serve(lis net.Listener) error {
} }
}() }()
ls := &listenSocket{Listener: lis} ls := &listenSocket{
Listener: lis,
channelz: channelz.RegisterSocket(&channelz.Socket{
SocketType: channelz.SocketTypeListen,
Parent: s.channelz,
RefName: lis.Addr().String(),
LocalAddr: lis.Addr(),
SocketOptions: channelz.GetSocketOption(lis)},
),
}
s.lis[ls] = true s.lis[ls] = true
defer func() { defer func() {
@ -869,14 +876,8 @@ func (s *Server) Serve(lis net.Listener) error {
s.mu.Unlock() s.mu.Unlock()
}() }()
var err error
ls.channelzID, err = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String())
if err != nil {
s.mu.Unlock() s.mu.Unlock()
return err channelz.Info(logger, ls.channelz, "ListenSocket created")
}
s.mu.Unlock()
channelz.Info(logger, ls.channelzID, "ListenSocket created")
var tempDelay time.Duration // how long to sleep on accept failure var tempDelay time.Duration // how long to sleep on accept failure
for { for {
@ -975,9 +976,10 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
WriteBufferSize: s.opts.writeBufferSize, WriteBufferSize: s.opts.writeBufferSize,
ReadBufferSize: s.opts.readBufferSize, ReadBufferSize: s.opts.readBufferSize,
SharedWriteBuffer: s.opts.sharedWriteBuffer, SharedWriteBuffer: s.opts.sharedWriteBuffer,
ChannelzParentID: s.channelzID, ChannelzParent: s.channelz,
MaxHeaderListSize: s.opts.maxHeaderListSize, MaxHeaderListSize: s.opts.maxHeaderListSize,
HeaderTableSize: s.opts.headerTableSize, HeaderTableSize: s.opts.headerTableSize,
BufferPool: s.opts.bufferPool,
} }
st, err := transport.NewServerTransport(c, config) st, err := transport.NewServerTransport(c, config)
if err != nil { if err != nil {
@ -989,7 +991,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
if err != credentials.ErrConnDispatched { if err != credentials.ErrConnDispatched {
// Don't log on ErrConnDispatched and io.EOF to prevent log spam. // Don't log on ErrConnDispatched and io.EOF to prevent log spam.
if err != io.EOF { if err != io.EOF {
channelz.Info(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) channelz.Info(logger, s.channelz, "grpc: Server.Serve failed to create ServerTransport: ", err)
} }
c.Close() c.Close()
} }
@ -1070,7 +1072,7 @@ var _ http.Handler = (*Server)(nil)
// Notice: This API is EXPERIMENTAL and may be changed or removed in a // Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release. // later release.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool)
if err != nil { if err != nil {
// Errors returned from transport.NewServerHandlerTransport have // Errors returned from transport.NewServerHandlerTransport have
// already been written to w. // already been written to w.
@ -1121,48 +1123,54 @@ func (s *Server) removeConn(addr string, st transport.ServerTransport) {
} }
} }
func (s *Server) channelzMetric() *channelz.ServerInternalMetric {
return &channelz.ServerInternalMetric{
CallsStarted: atomic.LoadInt64(&s.czData.callsStarted),
CallsSucceeded: atomic.LoadInt64(&s.czData.callsSucceeded),
CallsFailed: atomic.LoadInt64(&s.czData.callsFailed),
LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&s.czData.lastCallStartedTime)),
}
}
func (s *Server) incrCallsStarted() { func (s *Server) incrCallsStarted() {
atomic.AddInt64(&s.czData.callsStarted, 1) s.channelz.ServerMetrics.CallsStarted.Add(1)
atomic.StoreInt64(&s.czData.lastCallStartedTime, time.Now().UnixNano()) s.channelz.ServerMetrics.LastCallStartedTimestamp.Store(time.Now().UnixNano())
} }
func (s *Server) incrCallsSucceeded() { func (s *Server) incrCallsSucceeded() {
atomic.AddInt64(&s.czData.callsSucceeded, 1) s.channelz.ServerMetrics.CallsSucceeded.Add(1)
} }
func (s *Server) incrCallsFailed() { func (s *Server) incrCallsFailed() {
atomic.AddInt64(&s.czData.callsFailed, 1) s.channelz.ServerMetrics.CallsFailed.Add(1)
} }
func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
data, err := encode(s.getCodec(stream.ContentSubtype()), msg) data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
if err != nil { if err != nil {
channelz.Error(logger, s.channelzID, "grpc: server failed to encode response: ", err) channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
return err return err
} }
compData, err := compress(data, cp, comp)
compData, pf, err := compress(data, cp, comp, s.opts.bufferPool)
if err != nil { if err != nil {
channelz.Error(logger, s.channelzID, "grpc: server failed to compress response: ", err) data.Free()
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
return err return err
} }
hdr, payload := msgHeader(data, compData)
hdr, payload := msgHeader(data, compData, pf)
defer func() {
compData.Free()
data.Free()
// payload does not need to be freed here, it is either data or compData, both of
// which are already freed.
}()
dataLen := data.Len()
payloadLen := payload.Len()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > s.opts.maxSendMessageSize { if payloadLen > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize)
} }
err = t.Write(stream, hdr, payload, opts) err = t.Write(stream, hdr, payload, opts)
if err == nil { if err == nil {
if len(s.opts.statsHandlers) != 0 {
for _, sh := range s.opts.statsHandlers { for _, sh := range s.opts.statsHandlers {
sh.HandleRPC(ctx, outPayload(false, msg, data, payload, time.Now())) sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now()))
}
} }
} }
return err return err
@ -1341,14 +1349,17 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
var payInfo *payloadInfo var payInfo *payloadInfo
if len(shs) != 0 || len(binlogs) != 0 { if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free()
} }
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
d, err := recvAndDecompress(&parser{r: stream, bufferPool: s.opts.bufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp, true)
if err != nil { if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil { if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
} }
return err return err
} }
defer d.Free()
if channelz.IsOn() { if channelz.IsOn() {
t.IncrMsgRecv() t.IncrMsgRecv()
} }
@ -1356,19 +1367,19 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
} }
for _, sh := range shs { for _, sh := range shs {
sh.HandleRPC(ctx, &stats.InPayload{ sh.HandleRPC(ctx, &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: v, Payload: v,
Length: len(d), Length: d.Len(),
WireLength: payInfo.compressedLength + headerLen, WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
Data: d,
}) })
} }
if len(binlogs) != 0 { if len(binlogs) != 0 {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
Message: d, Message: d.Materialize(),
} }
for _, binlog := range binlogs { for _, binlog := range binlogs {
binlog.Log(ctx, cm) binlog.Log(ctx, cm)
@ -1394,7 +1405,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
trInfo.tr.SetError() trInfo.tr.SetError()
} }
if e := t.WriteStatus(stream, appStatus); e != nil { if e := t.WriteStatus(stream, appStatus); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
} }
if len(binlogs) != 0 { if len(binlogs) != 0 {
if h, _ := stream.Header(); h.Len() > 0 { if h, _ := stream.Header(); h.Len() > 0 {
@ -1434,7 +1445,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
} }
if sts, ok := status.FromError(err); ok { if sts, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, sts); e != nil { if e := t.WriteStatus(stream, sts); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) channelz.Warningf(logger, s.channelz, "grpc: Server.processUnaryRPC failed to write status: %v", e)
} }
} else { } else {
switch st := err.(type) { switch st := err.(type) {
@ -1552,7 +1563,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
ctx: ctx, ctx: ctx,
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, p: &parser{r: stream, bufferPool: s.opts.bufferPool},
codec: s.getCodec(stream.ContentSubtype()), codec: s.getCodec(stream.ContentSubtype()),
maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize,
@ -1762,7 +1773,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError() ti.tr.SetError()
} }
channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err) channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
} }
if ti != nil { if ti != nil {
ti.tr.Finish() ti.tr.Finish()
@ -1819,7 +1830,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true)
ti.tr.SetError() ti.tr.SetError()
} }
channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err) channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err)
} }
if ti != nil { if ti != nil {
ti.tr.Finish() ti.tr.Finish()
@ -1891,8 +1902,7 @@ func (s *Server) stop(graceful bool) {
s.quit.Fire() s.quit.Fire()
defer s.done.Fire() defer s.done.Fire()
s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelz.ID) })
s.mu.Lock() s.mu.Lock()
s.closeListenersLocked() s.closeListenersLocked()
// Wait for serving threads to be ready to exit. Only then can we be sure no // Wait for serving threads to be ready to exit. Only then can we be sure no
@ -1968,12 +1978,12 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
return s.opts.codec return s.opts.codec
} }
if contentSubtype == "" { if contentSubtype == "" {
return encoding.GetCodec(proto.Name) return getCodec(proto.Name)
} }
codec := encoding.GetCodec(contentSubtype) codec := getCodec(contentSubtype)
if codec == nil { if codec == nil {
logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name) logger.Warningf("Unsupported codec %q. Defaulting to %q for now. This will start to fail in future releases.", contentSubtype, proto.Name)
return encoding.GetCodec(proto.Name) return getCodec(proto.Name)
} }
return codec return codec
} }
@ -2117,7 +2127,7 @@ func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
} }
return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil return stream.ClientAdvertisedCompressors(), nil
} }
// SetTrailer sets the trailer metadata that will be sent when an RPC returns. // SetTrailer sets the trailer metadata that will be sent when an RPC returns.
@ -2147,17 +2157,9 @@ func Method(ctx context.Context) (string, bool) {
return s.Method(), true return s.Method(), true
} }
type channelzServer struct {
s *Server
}
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}
// validateSendCompressor returns an error when given compressor name cannot be // validateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors. // handled by the server or the client based on the advertised compressors.
func validateSendCompressor(name, clientCompressors string) error { func validateSendCompressor(name string, clientCompressors []string) error {
if name == encoding.Identity { if name == encoding.Identity {
return nil return nil
} }
@ -2166,7 +2168,7 @@ func validateSendCompressor(name, clientCompressors string) error {
return fmt.Errorf("compressor not registered %q", name) return fmt.Errorf("compressor not registered %q", name)
} }
for _, c := range strings.Split(clientCompressors, ",") { for _, c := range clientCompressors {
if c == name { if c == name {
return nil // found match return nil // found match
} }

View File

@ -25,8 +25,11 @@ import (
"reflect" "reflect"
"time" "time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
) )
@ -41,11 +44,6 @@ const maxInt = int(^uint(0) >> 1)
// https://github.com/grpc/grpc/blob/master/doc/service_config.md // https://github.com/grpc/grpc/blob/master/doc/service_config.md
type MethodConfig = internalserviceconfig.MethodConfig type MethodConfig = internalserviceconfig.MethodConfig
type lbConfig struct {
name string
cfg serviceconfig.LoadBalancingConfig
}
// ServiceConfig is provided by the service provider and contains parameters for how // ServiceConfig is provided by the service provider and contains parameters for how
// clients that connect to the service should behave. // clients that connect to the service should behave.
// //
@ -55,14 +53,9 @@ type lbConfig struct {
type ServiceConfig struct { type ServiceConfig struct {
serviceconfig.Config serviceconfig.Config
// LB is the load balancer the service providers recommends. This is
// deprecated; lbConfigs is preferred. If lbConfig and LB are both present,
// lbConfig will be used.
LB *string
// lbConfig is the service config's load balancing configuration. If // lbConfig is the service config's load balancing configuration. If
// lbConfig and LB are both present, lbConfig will be used. // lbConfig and LB are both present, lbConfig will be used.
lbConfig *lbConfig lbConfig serviceconfig.LoadBalancingConfig
// Methods contains a map for the methods in this service. If there is an // Methods contains a map for the methods in this service. If there is an
// exact match for a method (i.e. /service/method) in the map, use the // exact match for a method (i.e. /service/method) in the map, use the
@ -164,38 +157,55 @@ type jsonMC struct {
// TODO(lyuxuan): delete this struct after cleaning up old service config implementation. // TODO(lyuxuan): delete this struct after cleaning up old service config implementation.
type jsonSC struct { type jsonSC struct {
LoadBalancingPolicy *string LoadBalancingPolicy *string
LoadBalancingConfig *internalserviceconfig.BalancerConfig LoadBalancingConfig *json.RawMessage
MethodConfig *[]jsonMC MethodConfig *[]jsonMC
RetryThrottling *retryThrottlingPolicy RetryThrottling *retryThrottlingPolicy
HealthCheckConfig *healthCheckConfig HealthCheckConfig *healthCheckConfig
} }
func init() { func init() {
internal.ParseServiceConfig = parseServiceConfig internal.ParseServiceConfig = func(js string) *serviceconfig.ParseResult {
return parseServiceConfig(js, defaultMaxCallAttempts)
}
} }
func parseServiceConfig(js string) *serviceconfig.ParseResult { func parseServiceConfig(js string, maxAttempts int) *serviceconfig.ParseResult {
if len(js) == 0 { if len(js) == 0 {
return &serviceconfig.ParseResult{Err: fmt.Errorf("no JSON service config provided")} return &serviceconfig.ParseResult{Err: fmt.Errorf("no JSON service config provided")}
} }
var rsc jsonSC var rsc jsonSC
err := json.Unmarshal([]byte(js), &rsc) err := json.Unmarshal([]byte(js), &rsc)
if err != nil { if err != nil {
logger.Warningf("grpc: unmarshaling service config %s: %v", js, err) logger.Warningf("grpc: unmarshalling service config %s: %v", js, err)
return &serviceconfig.ParseResult{Err: err} return &serviceconfig.ParseResult{Err: err}
} }
sc := ServiceConfig{ sc := ServiceConfig{
LB: rsc.LoadBalancingPolicy,
Methods: make(map[string]MethodConfig), Methods: make(map[string]MethodConfig),
retryThrottling: rsc.RetryThrottling, retryThrottling: rsc.RetryThrottling,
healthCheckConfig: rsc.HealthCheckConfig, healthCheckConfig: rsc.HealthCheckConfig,
rawJSONString: js, rawJSONString: js,
} }
if c := rsc.LoadBalancingConfig; c != nil { c := rsc.LoadBalancingConfig
sc.lbConfig = &lbConfig{ if c == nil {
name: c.Name, name := pickfirst.Name
cfg: c.Config, if rsc.LoadBalancingPolicy != nil {
name = *rsc.LoadBalancingPolicy
} }
if balancer.Get(name) == nil {
name = pickfirst.Name
} }
cfg := []map[string]any{{name: struct{}{}}}
strCfg, err := json.Marshal(cfg)
if err != nil {
return &serviceconfig.ParseResult{Err: fmt.Errorf("unexpected error marshaling simple LB config: %w", err)}
}
r := json.RawMessage(strCfg)
c = &r
}
cfg, err := gracefulswitch.ParseConfig(*c)
if err != nil {
return &serviceconfig.ParseResult{Err: err}
}
sc.lbConfig = cfg
if rsc.MethodConfig == nil { if rsc.MethodConfig == nil {
return &serviceconfig.ParseResult{Config: &sc} return &serviceconfig.ParseResult{Config: &sc}
@ -211,8 +221,8 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
WaitForReady: m.WaitForReady, WaitForReady: m.WaitForReady,
Timeout: (*time.Duration)(m.Timeout), Timeout: (*time.Duration)(m.Timeout),
} }
if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil { if mc.RetryPolicy, err = convertRetryPolicy(m.RetryPolicy, maxAttempts); err != nil {
logger.Warningf("grpc: unmarshaling service config %s: %v", js, err) logger.Warningf("grpc: unmarshalling service config %s: %v", js, err)
return &serviceconfig.ParseResult{Err: err} return &serviceconfig.ParseResult{Err: err}
} }
if m.MaxRequestMessageBytes != nil { if m.MaxRequestMessageBytes != nil {
@ -232,13 +242,13 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
for i, n := range *m.Name { for i, n := range *m.Name {
path, err := n.generatePath() path, err := n.generatePath()
if err != nil { if err != nil {
logger.Warningf("grpc: error unmarshaling service config %s due to methodConfig[%d]: %v", js, i, err) logger.Warningf("grpc: error unmarshalling service config %s due to methodConfig[%d]: %v", js, i, err)
return &serviceconfig.ParseResult{Err: err} return &serviceconfig.ParseResult{Err: err}
} }
if _, ok := paths[path]; ok { if _, ok := paths[path]; ok {
err = errDuplicatedName err = errDuplicatedName
logger.Warningf("grpc: error unmarshaling service config %s due to methodConfig[%d]: %v", js, i, err) logger.Warningf("grpc: error unmarshalling service config %s due to methodConfig[%d]: %v", js, i, err)
return &serviceconfig.ParseResult{Err: err} return &serviceconfig.ParseResult{Err: err}
} }
paths[path] = struct{}{} paths[path] = struct{}{}
@ -257,7 +267,7 @@ func parseServiceConfig(js string) *serviceconfig.ParseResult {
return &serviceconfig.ParseResult{Config: &sc} return &serviceconfig.ParseResult{Config: &sc}
} }
func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPolicy, err error) { func convertRetryPolicy(jrp *jsonRetryPolicy, maxAttempts int) (p *internalserviceconfig.RetryPolicy, err error) {
if jrp == nil { if jrp == nil {
return nil, nil return nil, nil
} }
@ -271,17 +281,16 @@ func convertRetryPolicy(jrp *jsonRetryPolicy) (p *internalserviceconfig.RetryPol
return nil, nil return nil, nil
} }
if jrp.MaxAttempts < maxAttempts {
maxAttempts = jrp.MaxAttempts
}
rp := &internalserviceconfig.RetryPolicy{ rp := &internalserviceconfig.RetryPolicy{
MaxAttempts: jrp.MaxAttempts, MaxAttempts: maxAttempts,
InitialBackoff: time.Duration(jrp.InitialBackoff), InitialBackoff: time.Duration(jrp.InitialBackoff),
MaxBackoff: time.Duration(jrp.MaxBackoff), MaxBackoff: time.Duration(jrp.MaxBackoff),
BackoffMultiplier: jrp.BackoffMultiplier, BackoffMultiplier: jrp.BackoffMultiplier,
RetryableStatusCodes: make(map[codes.Code]bool), RetryableStatusCodes: make(map[codes.Code]bool),
} }
if rp.MaxAttempts > 5 {
// TODO(retry): Make the max maxAttempts configurable.
rp.MaxAttempts = 5
}
for _, code := range jrp.RetryableStatusCodes { for _, code := range jrp.RetryableStatusCodes {
rp.RetryableStatusCodes[code] = true rp.RetryableStatusCodes[code] = true
} }

View File

@ -1,154 +0,0 @@
/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import "sync"
// SharedBufferPool is a pool of buffers that can be shared, resulting in
// decreased memory allocation. Currently, in gRPC-go, it is only utilized
// for parsing incoming messages.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
type SharedBufferPool interface {
// Get returns a buffer with specified length from the pool.
//
// The returned byte slice may be not zero initialized.
Get(length int) []byte
// Put returns a buffer to the pool.
Put(*[]byte)
}
// NewSharedBufferPool creates a simple SharedBufferPool with buckets
// of different sizes to optimize memory usage. This prevents the pool from
// wasting large amounts of memory, even when handling messages of varying sizes.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func NewSharedBufferPool() SharedBufferPool {
return &simpleSharedBufferPool{
pools: [poolArraySize]simpleSharedBufferChildPool{
newBytesPool(level0PoolMaxSize),
newBytesPool(level1PoolMaxSize),
newBytesPool(level2PoolMaxSize),
newBytesPool(level3PoolMaxSize),
newBytesPool(level4PoolMaxSize),
newBytesPool(0),
},
}
}
// simpleSharedBufferPool is a simple implementation of SharedBufferPool.
type simpleSharedBufferPool struct {
pools [poolArraySize]simpleSharedBufferChildPool
}
func (p *simpleSharedBufferPool) Get(size int) []byte {
return p.pools[p.poolIdx(size)].Get(size)
}
func (p *simpleSharedBufferPool) Put(bs *[]byte) {
p.pools[p.poolIdx(cap(*bs))].Put(bs)
}
func (p *simpleSharedBufferPool) poolIdx(size int) int {
switch {
case size <= level0PoolMaxSize:
return level0PoolIdx
case size <= level1PoolMaxSize:
return level1PoolIdx
case size <= level2PoolMaxSize:
return level2PoolIdx
case size <= level3PoolMaxSize:
return level3PoolIdx
case size <= level4PoolMaxSize:
return level4PoolIdx
default:
return levelMaxPoolIdx
}
}
const (
level0PoolMaxSize = 16 // 16 B
level1PoolMaxSize = level0PoolMaxSize * 16 // 256 B
level2PoolMaxSize = level1PoolMaxSize * 16 // 4 KB
level3PoolMaxSize = level2PoolMaxSize * 16 // 64 KB
level4PoolMaxSize = level3PoolMaxSize * 16 // 1 MB
)
const (
level0PoolIdx = iota
level1PoolIdx
level2PoolIdx
level3PoolIdx
level4PoolIdx
levelMaxPoolIdx
poolArraySize
)
type simpleSharedBufferChildPool interface {
Get(size int) []byte
Put(any)
}
type bufferPool struct {
sync.Pool
defaultSize int
}
func (p *bufferPool) Get(size int) []byte {
bs := p.Pool.Get().(*[]byte)
if cap(*bs) < size {
p.Pool.Put(bs)
return make([]byte, size)
}
return (*bs)[:size]
}
func newBytesPool(size int) simpleSharedBufferChildPool {
return &bufferPool{
Pool: sync.Pool{
New: func() any {
bs := make([]byte, size)
return &bs
},
},
defaultSize: size,
}
}
// nopBufferPool is a buffer pool just makes new buffer without pooling.
type nopBufferPool struct {
}
func (nopBufferPool) Get(length int) []byte {
return make([]byte, length)
}
func (nopBufferPool) Put(*[]byte) {
}

View File

@ -73,10 +73,10 @@ func (*PickerUpdated) isRPCStats() {}
type InPayload struct { type InPayload struct {
// Client is true if this InPayload is from client side. // Client is true if this InPayload is from client side.
Client bool Client bool
// Payload is the payload with original type. // Payload is the payload with original type. This may be modified after
// the call to HandleRPC which provides the InPayload returns and must be
// copied if needed later.
Payload any Payload any
// Data is the serialized message payload.
Data []byte
// Length is the size of the uncompressed payload data. Does not include any // Length is the size of the uncompressed payload data. Does not include any
// framing (gRPC or HTTP/2). // framing (gRPC or HTTP/2).
@ -143,10 +143,10 @@ func (s *InTrailer) isRPCStats() {}
type OutPayload struct { type OutPayload struct {
// Client is true if this OutPayload is from client side. // Client is true if this OutPayload is from client side.
Client bool Client bool
// Payload is the payload with original type. // Payload is the payload with original type. This may be modified after
// the call to HandleRPC which provides the OutPayload returns and must be
// copied if needed later.
Payload any Payload any
// Data is the serialized message payload.
Data []byte
// Length is the size of the uncompressed payload data. Does not include any // Length is the size of the uncompressed payload data. Does not include any
// framing (gRPC or HTTP/2). // framing (gRPC or HTTP/2).
Length int Length int

View File

@ -23,6 +23,7 @@ import (
"errors" "errors"
"io" "io"
"math" "math"
"math/rand"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -34,13 +35,13 @@ import (
"google.golang.org/grpc/internal/balancerload" "google.golang.org/grpc/internal/balancerload"
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata" imetadata "google.golang.org/grpc/internal/metadata"
iresolver "google.golang.org/grpc/internal/resolver" iresolver "google.golang.org/grpc/internal/resolver"
"google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/serviceconfig"
istatus "google.golang.org/grpc/internal/status" istatus "google.golang.org/grpc/internal/status"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -359,7 +360,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client
cs.attempt = a cs.attempt = a
return nil return nil
} }
if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil { if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op, nil) }); err != nil {
return nil, err return nil, err
} }
@ -516,7 +517,8 @@ func (a *csAttempt) newStream() error {
return toRPCErr(nse.Err) return toRPCErr(nse.Err)
} }
a.s = s a.s = s
a.p = &parser{r: s, recvBufferPool: a.cs.cc.dopts.recvBufferPool} a.ctx = s.Context()
a.p = &parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool}
return nil return nil
} }
@ -567,8 +569,13 @@ type clientStream struct {
// TODO(hedging): hedging will have multiple attempts simultaneously. // TODO(hedging): hedging will have multiple attempts simultaneously.
committed bool // active attempt committed for retry? committed bool // active attempt committed for retry?
onCommit func() onCommit func()
buffer []func(a *csAttempt) error // operations to replay on retry replayBuffer []replayOp // operations to replay on retry
bufferSize int // current size of buffer replayBufferSize int // current size of replayBuffer
}
type replayOp struct {
op func(a *csAttempt) error
cleanup func()
} }
// csAttempt implements a single transport stream attempt within a // csAttempt implements a single transport stream attempt within a
@ -606,7 +613,12 @@ func (cs *clientStream) commitAttemptLocked() {
cs.onCommit() cs.onCommit()
} }
cs.committed = true cs.committed = true
cs.buffer = nil for _, op := range cs.replayBuffer {
if op.cleanup != nil {
op.cleanup()
}
}
cs.replayBuffer = nil
} }
func (cs *clientStream) commitAttempt() { func (cs *clientStream) commitAttempt() {
@ -655,13 +667,13 @@ func (a *csAttempt) shouldRetry(err error) (bool, error) {
if len(sps) == 1 { if len(sps) == 1 {
var e error var e error
if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 { if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 {
channelz.Infof(logger, cs.cc.channelzID, "Server retry pushback specified to abort (%q).", sps[0]) channelz.Infof(logger, cs.cc.channelz, "Server retry pushback specified to abort (%q).", sps[0])
cs.retryThrottler.throttle() // This counts as a failure for throttling. cs.retryThrottler.throttle() // This counts as a failure for throttling.
return false, err return false, err
} }
hasPushback = true hasPushback = true
} else if len(sps) > 1 { } else if len(sps) > 1 {
channelz.Warningf(logger, cs.cc.channelzID, "Server retry pushback specified multiple values (%q); not retrying.", sps) channelz.Warningf(logger, cs.cc.channelz, "Server retry pushback specified multiple values (%q); not retrying.", sps)
cs.retryThrottler.throttle() // This counts as a failure for throttling. cs.retryThrottler.throttle() // This counts as a failure for throttling.
return false, err return false, err
} }
@ -698,7 +710,7 @@ func (a *csAttempt) shouldRetry(err error) (bool, error) {
if max := float64(rp.MaxBackoff); cur > max { if max := float64(rp.MaxBackoff); cur > max {
cur = max cur = max
} }
dur = time.Duration(grpcrand.Int63n(int64(cur))) dur = time.Duration(rand.Int63n(int64(cur)))
cs.numRetriesSincePushback++ cs.numRetriesSincePushback++
} }
@ -731,7 +743,7 @@ func (cs *clientStream) retryLocked(attempt *csAttempt, lastErr error) error {
// the stream is canceled. // the stream is canceled.
return err return err
} }
// Note that the first op in the replay buffer always sets cs.attempt // Note that the first op in replayBuffer always sets cs.attempt
// if it is able to pick a transport and create a stream. // if it is able to pick a transport and create a stream.
if lastErr = cs.replayBufferLocked(attempt); lastErr == nil { if lastErr = cs.replayBufferLocked(attempt); lastErr == nil {
return nil return nil
@ -760,7 +772,7 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())
// already be status errors. // already be status errors.
return toRPCErr(op(cs.attempt)) return toRPCErr(op(cs.attempt))
} }
if len(cs.buffer) == 0 { if len(cs.replayBuffer) == 0 {
// For the first op, which controls creation of the stream and // For the first op, which controls creation of the stream and
// assigns cs.attempt, we need to create a new attempt inline // assigns cs.attempt, we need to create a new attempt inline
// before executing the first op. On subsequent ops, the attempt // before executing the first op. On subsequent ops, the attempt
@ -850,25 +862,26 @@ func (cs *clientStream) Trailer() metadata.MD {
} }
func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error { func (cs *clientStream) replayBufferLocked(attempt *csAttempt) error {
for _, f := range cs.buffer { for _, f := range cs.replayBuffer {
if err := f(attempt); err != nil { if err := f.op(attempt); err != nil {
return err return err
} }
} }
return nil return nil
} }
func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error) { func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error, cleanup func()) {
// Note: we still will buffer if retry is disabled (for transparent retries). // Note: we still will buffer if retry is disabled (for transparent retries).
if cs.committed { if cs.committed {
return return
} }
cs.bufferSize += sz cs.replayBufferSize += sz
if cs.bufferSize > cs.callInfo.maxRetryRPCBufferSize { if cs.replayBufferSize > cs.callInfo.maxRetryRPCBufferSize {
cs.commitAttemptLocked() cs.commitAttemptLocked()
cleanup()
return return
} }
cs.buffer = append(cs.buffer, op) cs.replayBuffer = append(cs.replayBuffer, replayOp{op: op, cleanup: cleanup})
} }
func (cs *clientStream) SendMsg(m any) (err error) { func (cs *clientStream) SendMsg(m any) (err error) {
@ -890,23 +903,50 @@ func (cs *clientStream) SendMsg(m any) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp) hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.cp, cs.comp, cs.cc.dopts.copts.BufferPool)
if err != nil { if err != nil {
return err return err
} }
defer func() {
data.Free()
// only free payload if compression was made, and therefore it is a different set
// of buffers from data.
if pf.isCompressed() {
payload.Free()
}
}()
dataLen := data.Len()
payloadLen := payload.Len()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > *cs.callInfo.maxSendMessageSize { if payloadLen > *cs.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payloadLen, *cs.callInfo.maxSendMessageSize)
} }
// always take an extra ref in case data == payload (i.e. when the data isn't
// compressed). The original ref will always be freed by the deferred free above.
payload.Ref()
op := func(a *csAttempt) error { op := func(a *csAttempt) error {
return a.sendMsg(m, hdr, payload, data) return a.sendMsg(m, hdr, payload, dataLen, payloadLen)
}
// onSuccess is invoked when the op is captured for a subsequent retry. If the
// stream was established by a previous message and therefore retries are
// disabled, onSuccess will not be invoked, and payloadRef can be freed
// immediately.
onSuccessCalled := false
err = cs.withRetry(op, func() {
cs.bufferForRetryLocked(len(hdr)+payloadLen, op, payload.Free)
onSuccessCalled = true
})
if !onSuccessCalled {
payload.Free()
} }
err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
if len(cs.binlogs) != 0 && err == nil { if len(cs.binlogs) != 0 && err == nil {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
OnClientSide: true, OnClientSide: true,
Message: data, Message: data.Materialize(),
} }
for _, binlog := range cs.binlogs { for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, cm) binlog.Log(cs.ctx, cm)
@ -923,6 +963,7 @@ func (cs *clientStream) RecvMsg(m any) error {
var recvInfo *payloadInfo var recvInfo *payloadInfo
if len(cs.binlogs) != 0 { if len(cs.binlogs) != 0 {
recvInfo = &payloadInfo{} recvInfo = &payloadInfo{}
defer recvInfo.free()
} }
err := cs.withRetry(func(a *csAttempt) error { err := cs.withRetry(func(a *csAttempt) error {
return a.recvMsg(m, recvInfo) return a.recvMsg(m, recvInfo)
@ -930,7 +971,7 @@ func (cs *clientStream) RecvMsg(m any) error {
if len(cs.binlogs) != 0 && err == nil { if len(cs.binlogs) != 0 && err == nil {
sm := &binarylog.ServerMessage{ sm := &binarylog.ServerMessage{
OnClientSide: true, OnClientSide: true,
Message: recvInfo.uncompressedBytes, Message: recvInfo.uncompressedBytes.Materialize(),
} }
for _, binlog := range cs.binlogs { for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, sm) binlog.Log(cs.ctx, sm)
@ -957,7 +998,7 @@ func (cs *clientStream) CloseSend() error {
// RecvMsg. This also matches historical behavior. // RecvMsg. This also matches historical behavior.
return nil return nil
} }
cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }) cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op, nil) })
if len(cs.binlogs) != 0 { if len(cs.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{ chc := &binarylog.ClientHalfClose{
OnClientSide: true, OnClientSide: true,
@ -1033,7 +1074,7 @@ func (cs *clientStream) finish(err error) {
cs.cancel() cs.cancel()
} }
func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error { func (a *csAttempt) sendMsg(m any, hdr []byte, payld mem.BufferSlice, dataLength, payloadLength int) error {
cs := a.cs cs := a.cs
if a.trInfo != nil { if a.trInfo != nil {
a.mu.Lock() a.mu.Lock()
@ -1051,8 +1092,10 @@ func (a *csAttempt) sendMsg(m any, hdr, payld, data []byte) error {
} }
return io.EOF return io.EOF
} }
if len(a.statsHandlers) != 0 {
for _, sh := range a.statsHandlers { for _, sh := range a.statsHandlers {
sh.HandleRPC(a.ctx, outPayload(true, m, data, payld, time.Now())) sh.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now()))
}
} }
if channelz.IsOn() { if channelz.IsOn() {
a.t.IncrMsgSent() a.t.IncrMsgSent()
@ -1064,6 +1107,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
cs := a.cs cs := a.cs
if len(a.statsHandlers) != 0 && payInfo == nil { if len(a.statsHandlers) != 0 && payInfo == nil {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free()
} }
if !a.decompSet { if !a.decompSet {
@ -1082,8 +1126,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
a.decompSet = true a.decompSet = true
} }
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp) if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decomp, false); err != nil {
if err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil { if statusErr := a.s.Status().Err(); statusErr != nil {
return statusErr return statusErr
@ -1105,11 +1148,9 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
Client: true, Client: true,
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: m, Payload: m,
// TODO truncate large payload.
Data: payInfo.uncompressedBytes,
WireLength: payInfo.compressedLength + headerLen, WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
Length: len(payInfo.uncompressedBytes), Length: payInfo.uncompressedBytes.Len(),
}) })
} }
if channelz.IsOn() { if channelz.IsOn() {
@ -1121,14 +1162,12 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) {
} }
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp) if err := recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp, false); err == io.EOF {
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
return a.s.Status().Err() // non-server streaming Recv returns nil on success return a.s.Status().Err() // non-server streaming Recv returns nil on success
} } else if err != nil {
return toRPCErr(err) return toRPCErr(err)
}
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
} }
func (a *csAttempt) finish(err error) { func (a *csAttempt) finish(err error) {
@ -1184,12 +1223,12 @@ func (a *csAttempt) finish(err error) {
a.mu.Unlock() a.mu.Unlock()
} }
// newClientStream creates a ClientStream with the specified transport, on the // newNonRetryClientStream creates a ClientStream with the specified transport, on the
// given addrConn. // given addrConn.
// //
// It's expected that the given transport is either the same one in addrConn, or // It's expected that the given transport is either the same one in addrConn, or
// is already closed. To avoid race, transport is specified separately, instead // is already closed. To avoid race, transport is specified separately, instead
// of using ac.transpot. // of using ac.transport.
// //
// Main difference between this and ClientConn.NewStream: // Main difference between this and ClientConn.NewStream:
// - no retry // - no retry
@ -1275,7 +1314,7 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
return nil, err return nil, err
} }
as.s = s as.s = s
as.p = &parser{r: s, recvBufferPool: ac.dopts.recvBufferPool} as.p = &parser{r: s, bufferPool: ac.dopts.copts.BufferPool}
ac.incrCallsStarted() ac.incrCallsStarted()
if desc != unaryStreamDesc { if desc != unaryStreamDesc {
// Listen on stream context to cleanup when the stream context is // Listen on stream context to cleanup when the stream context is
@ -1372,17 +1411,26 @@ func (as *addrConnStream) SendMsg(m any) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp) hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.cp, as.comp, as.ac.dopts.copts.BufferPool)
if err != nil { if err != nil {
return err return err
} }
defer func() {
data.Free()
// only free payload if compression was made, and therefore it is a different set
// of buffers from data.
if pf.isCompressed() {
payload.Free()
}
}()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payld) > *as.callInfo.maxSendMessageSize { if payload.Len() > *as.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payload.Len(), *as.callInfo.maxSendMessageSize)
} }
if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams}); err != nil { if err := as.t.Write(as.s, hdr, payload, &transport.Options{Last: !as.desc.ClientStreams}); err != nil {
if !as.desc.ClientStreams { if !as.desc.ClientStreams {
// For non-client-streaming RPCs, we return nil instead of EOF on error // For non-client-streaming RPCs, we return nil instead of EOF on error
// because the generated code requires it. finish is not called; RecvMsg() // because the generated code requires it. finish is not called; RecvMsg()
@ -1422,8 +1470,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
as.decompSet = true as.decompSet = true
} }
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err != nil {
if err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := as.s.Status().Err(); statusErr != nil { if statusErr := as.s.Status().Err(); statusErr != nil {
return statusErr return statusErr
@ -1443,14 +1490,12 @@ func (as *addrConnStream) RecvMsg(m any) (err error) {
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) if err := recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp, false); err == io.EOF {
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
if err == io.EOF {
return as.s.Status().Err() // non-server streaming Recv returns nil on success return as.s.Status().Err() // non-server streaming Recv returns nil on success
} } else if err != nil {
return toRPCErr(err) return toRPCErr(err)
}
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
} }
func (as *addrConnStream) finish(err error) { func (as *addrConnStream) finish(err error) {
@ -1644,18 +1689,31 @@ func (ss *serverStream) SendMsg(m any) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp) hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.cp, ss.comp, ss.p.bufferPool)
if err != nil { if err != nil {
return err return err
} }
defer func() {
data.Free()
// only free payload if compression was made, and therefore it is a different set
// of buffers from data.
if pf.isCompressed() {
payload.Free()
}
}()
dataLen := data.Len()
payloadLen := payload.Len()
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > ss.maxSendMessageSize { if payloadLen > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", payloadLen, ss.maxSendMessageSize)
} }
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil { if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
if !ss.serverHeaderBinlogged { if !ss.serverHeaderBinlogged {
h, _ := ss.s.Header() h, _ := ss.s.Header()
@ -1668,7 +1726,7 @@ func (ss *serverStream) SendMsg(m any) (err error) {
} }
} }
sm := &binarylog.ServerMessage{ sm := &binarylog.ServerMessage{
Message: data, Message: data.Materialize(),
} }
for _, binlog := range ss.binlogs { for _, binlog := range ss.binlogs {
binlog.Log(ss.ctx, sm) binlog.Log(ss.ctx, sm)
@ -1676,7 +1734,7 @@ func (ss *serverStream) SendMsg(m any) (err error) {
} }
if len(ss.statsHandler) != 0 { if len(ss.statsHandler) != 0 {
for _, sh := range ss.statsHandler { for _, sh := range ss.statsHandler {
sh.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now())) sh.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now()))
} }
} }
return nil return nil
@ -1713,8 +1771,9 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
var payInfo *payloadInfo var payInfo *payloadInfo
if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 {
payInfo = &payloadInfo{} payInfo = &payloadInfo{}
defer payInfo.free()
} }
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp); err != nil { if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, payInfo, ss.decomp, true); err != nil {
if err == io.EOF { if err == io.EOF {
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
chc := &binarylog.ClientHalfClose{} chc := &binarylog.ClientHalfClose{}
@ -1734,9 +1793,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
sh.HandleRPC(ss.s.Context(), &stats.InPayload{ sh.HandleRPC(ss.s.Context(), &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: m, Payload: m,
// TODO truncate large payload. Length: payInfo.uncompressedBytes.Len(),
Data: payInfo.uncompressedBytes,
Length: len(payInfo.uncompressedBytes),
WireLength: payInfo.compressedLength + headerLen, WireLength: payInfo.compressedLength + headerLen,
CompressedLength: payInfo.compressedLength, CompressedLength: payInfo.compressedLength,
}) })
@ -1744,7 +1801,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) {
} }
if len(ss.binlogs) != 0 { if len(ss.binlogs) != 0 {
cm := &binarylog.ClientMessage{ cm := &binarylog.ClientMessage{
Message: payInfo.uncompressedBytes, Message: payInfo.uncompressedBytes.Materialize(),
} }
for _, binlog := range ss.binlogs { for _, binlog := range ss.binlogs {
binlog.Log(ss.ctx, cm) binlog.Log(ss.ctx, cm)
@ -1759,23 +1816,26 @@ func MethodFromServerStream(stream ServerStream) (string, bool) {
return Method(stream.Context()) return Method(stream.Context())
} }
// prepareMsg returns the hdr, payload and data // prepareMsg returns the hdr, payload and data using the compressors passed or
// using the compressors passed or using the // using the passed preparedmsg. The returned boolean indicates whether
// passed preparedmsg // compression was made and therefore whether the payload needs to be freed in
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) { // addition to the returned data. Freeing the payload if the returned boolean is
// false can lead to undefined behavior.
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) {
if preparedMsg, ok := m.(*PreparedMsg); ok { if preparedMsg, ok := m.(*PreparedMsg); ok {
return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil return preparedMsg.hdr, preparedMsg.encodedData, preparedMsg.payload, preparedMsg.pf, nil
} }
// The input interface is not a prepared msg. // The input interface is not a prepared msg.
// Marshal and Compress the data at this point // Marshal and Compress the data at this point
data, err = encode(codec, m) data, err = encode(codec, m)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, 0, err
} }
compData, err := compress(data, cp, comp) compData, pf, err := compress(data, cp, comp, pool)
if err != nil { if err != nil {
return nil, nil, nil, err data.Free()
return nil, nil, nil, 0, err
} }
hdr, payload = msgHeader(data, compData) hdr, payload = msgHeader(data, compData, pf)
return hdr, payload, data, nil return hdr, data, payload, pf, nil
} }

152
vendor/google.golang.org/grpc/stream_interfaces.go generated vendored Normal file
View File

@ -0,0 +1,152 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
// ServerStreamingClient represents the client side of a server-streaming (one
// request, many responses) RPC. It is generic over the type of the response
// message. It is used in generated code.
type ServerStreamingClient[Res any] interface {
Recv() (*Res, error)
ClientStream
}
// ServerStreamingServer represents the server side of a server-streaming (one
// request, many responses) RPC. It is generic over the type of the response
// message. It is used in generated code.
type ServerStreamingServer[Res any] interface {
Send(*Res) error
ServerStream
}
// ClientStreamingClient represents the client side of a client-streaming (many
// requests, one response) RPC. It is generic over both the type of the request
// message stream and the type of the unary response message. It is used in
// generated code.
type ClientStreamingClient[Req any, Res any] interface {
Send(*Req) error
CloseAndRecv() (*Res, error)
ClientStream
}
// ClientStreamingServer represents the server side of a client-streaming (many
// requests, one response) RPC. It is generic over both the type of the request
// message stream and the type of the unary response message. It is used in
// generated code.
type ClientStreamingServer[Req any, Res any] interface {
Recv() (*Req, error)
SendAndClose(*Res) error
ServerStream
}
// BidiStreamingClient represents the client side of a bidirectional-streaming
// (many requests, many responses) RPC. It is generic over both the type of the
// request message stream and the type of the response message stream. It is
// used in generated code.
type BidiStreamingClient[Req any, Res any] interface {
Send(*Req) error
Recv() (*Res, error)
ClientStream
}
// BidiStreamingServer represents the server side of a bidirectional-streaming
// (many requests, many responses) RPC. It is generic over both the type of the
// request message stream and the type of the response message stream. It is
// used in generated code.
type BidiStreamingServer[Req any, Res any] interface {
Recv() (*Req, error)
Send(*Res) error
ServerStream
}
// GenericClientStream implements the ServerStreamingClient, ClientStreamingClient,
// and BidiStreamingClient interfaces. It is used in generated code.
type GenericClientStream[Req any, Res any] struct {
ClientStream
}
var _ ServerStreamingClient[string] = (*GenericClientStream[int, string])(nil)
var _ ClientStreamingClient[int, string] = (*GenericClientStream[int, string])(nil)
var _ BidiStreamingClient[int, string] = (*GenericClientStream[int, string])(nil)
// Send pushes one message into the stream of requests to be consumed by the
// server. The type of message which can be sent is determined by the Req type
// parameter of the GenericClientStream receiver.
func (x *GenericClientStream[Req, Res]) Send(m *Req) error {
return x.ClientStream.SendMsg(m)
}
// Recv reads one message from the stream of responses generated by the server.
// The type of the message returned is determined by the Res type parameter
// of the GenericClientStream receiver.
func (x *GenericClientStream[Req, Res]) Recv() (*Res, error) {
m := new(Res)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// CloseAndRecv closes the sending side of the stream, then receives the unary
// response from the server. The type of message which it returns is determined
// by the Res type parameter of the GenericClientStream receiver.
func (x *GenericClientStream[Req, Res]) CloseAndRecv() (*Res, error) {
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
m := new(Res)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// GenericServerStream implements the ServerStreamingServer, ClientStreamingServer,
// and BidiStreamingServer interfaces. It is used in generated code.
type GenericServerStream[Req any, Res any] struct {
ServerStream
}
var _ ServerStreamingServer[string] = (*GenericServerStream[int, string])(nil)
var _ ClientStreamingServer[int, string] = (*GenericServerStream[int, string])(nil)
var _ BidiStreamingServer[int, string] = (*GenericServerStream[int, string])(nil)
// Send pushes one message into the stream of responses to be consumed by the
// client. The type of message which can be sent is determined by the Res
// type parameter of the serverStreamServer receiver.
func (x *GenericServerStream[Req, Res]) Send(m *Res) error {
return x.ServerStream.SendMsg(m)
}
// SendAndClose pushes the unary response to the client. The type of message
// which can be sent is determined by the Res type parameter of the
// clientStreamServer receiver.
func (x *GenericServerStream[Req, Res]) SendAndClose(m *Res) error {
return x.ServerStream.SendMsg(m)
}
// Recv reads one message from the stream of requests generated by the client.
// The type of the message returned is determined by the Req type parameter
// of the clientStreamServer receiver.
func (x *GenericServerStream[Req, Res]) Recv() (*Req, error) {
m := new(Req)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}

View File

@ -19,4 +19,4 @@
package grpc package grpc
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.62.0" const Version = "1.66.2"

190
vendor/google.golang.org/grpc/vet.sh generated vendored
View File

@ -1,190 +0,0 @@
#!/bin/bash
set -ex # Exit on error; debugging enabled.
set -o pipefail # Fail a pipe if any sub-command fails.
# not makes sure the command passed to it does not exit with a return code of 0.
not() {
# This is required instead of the earlier (! $COMMAND) because subshells and
# pipefail don't work the same on Darwin as in Linux.
! "$@"
}
die() {
echo "$@" >&2
exit 1
}
fail_on_output() {
tee /dev/stderr | not read
}
# Check to make sure it's safe to modify the user's git repo.
git status --porcelain | fail_on_output
# Undo any edits made by this script.
cleanup() {
git reset --hard HEAD
}
trap cleanup EXIT
PATH="${HOME}/go/bin:${GOROOT}/bin:${PATH}"
go version
if [[ "$1" = "-install" ]]; then
# Install the pinned versions as defined in module tools.
pushd ./test/tools
go install \
golang.org/x/tools/cmd/goimports \
honnef.co/go/tools/cmd/staticcheck \
github.com/client9/misspell/cmd/misspell
popd
if [[ -z "${VET_SKIP_PROTO}" ]]; then
if [[ "${GITHUB_ACTIONS}" = "true" ]]; then
PROTOBUF_VERSION=25.2 # a.k.a. v4.22.0 in pb.go files.
PROTOC_FILENAME=protoc-${PROTOBUF_VERSION}-linux-x86_64.zip
pushd /home/runner/go
wget https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/${PROTOC_FILENAME}
unzip ${PROTOC_FILENAME}
bin/protoc --version
popd
elif not which protoc > /dev/null; then
die "Please install protoc into your path"
fi
fi
exit 0
elif [[ "$#" -ne 0 ]]; then
die "Unknown argument(s): $*"
fi
# - Check that generated proto files are up to date.
if [[ -z "${VET_SKIP_PROTO}" ]]; then
make proto && git status --porcelain 2>&1 | fail_on_output || \
(git status; git --no-pager diff; exit 1)
fi
if [[ -n "${VET_ONLY_PROTO}" ]]; then
exit 0
fi
# - Ensure all source files contain a copyright message.
# (Done in two parts because Darwin "git grep" has broken support for compound
# exclusion matches.)
(grep -L "DO NOT EDIT" $(git grep -L "\(Copyright [0-9]\{4,\} gRPC authors\)" -- '*.go') || true) | fail_on_output
# - Make sure all tests in grpc and grpc/test use leakcheck via Teardown.
not grep 'func Test[^(]' *_test.go
not grep 'func Test[^(]' test/*.go
# - Check for typos in test function names
git grep 'func (s) ' -- "*_test.go" | not grep -v 'func (s) Test'
git grep 'func [A-Z]' -- "*_test.go" | not grep -v 'func Test\|Benchmark\|Example'
# - Do not import x/net/context.
not git grep -l 'x/net/context' -- "*.go"
# - Do not import math/rand for real library code. Use internal/grpcrand for
# thread safety.
git grep -l '"math/rand"' -- "*.go" 2>&1 | not grep -v '^examples\|^interop/stress\|grpcrand\|^benchmark\|wrr_test'
# - Do not use "interface{}"; use "any" instead.
git grep -l 'interface{}' -- "*.go" 2>&1 | not grep -v '\.pb\.go\|protoc-gen-go-grpc\|grpc_testing_not_regenerate'
# - Do not call grpclog directly. Use grpclog.Component instead.
git grep -l -e 'grpclog.I' --or -e 'grpclog.W' --or -e 'grpclog.E' --or -e 'grpclog.F' --or -e 'grpclog.V' -- "*.go" | not grep -v '^grpclog/component.go\|^internal/grpctest/tlogger_test.go'
# - Ensure all ptypes proto packages are renamed when importing.
not git grep "\(import \|^\s*\)\"github.com/golang/protobuf/ptypes/" -- "*.go"
# - Ensure all usages of grpc_testing package are renamed when importing.
not git grep "\(import \|^\s*\)\"google.golang.org/grpc/interop/grpc_testing" -- "*.go"
# - Ensure all xds proto imports are renamed to *pb or *grpc.
git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.pb.go' | not grep -v 'pb "\|grpc "'
misspell -error .
# - gofmt, goimports, go vet, go mod tidy.
# Perform these checks on each module inside gRPC.
for MOD_FILE in $(find . -name 'go.mod'); do
MOD_DIR=$(dirname ${MOD_FILE})
pushd ${MOD_DIR}
go vet -all ./... | fail_on_output
gofmt -s -d -l . 2>&1 | fail_on_output
goimports -l . 2>&1 | not grep -vE "\.pb\.go"
go mod tidy -compat=1.19
git status --porcelain 2>&1 | fail_on_output || \
(git status; git --no-pager diff; exit 1)
popd
done
# - Collection of static analysis checks
SC_OUT="$(mktemp)"
staticcheck -go 1.19 -checks 'all' ./... > "${SC_OUT}" || true
# Error for anything other than checks that need exclusions.
grep -v "(ST1000)" "${SC_OUT}" | grep -v "(SA1019)" | grep -v "(ST1003)" | not grep -v "(ST1019)\|\(other import of\)"
# Exclude underscore checks for generated code.
grep "(ST1003)" "${SC_OUT}" | not grep -v '\(.pb.go:\)\|\(code_string_test.go:\)\|\(grpc_testing_not_regenerate\)'
# Error for duplicate imports not including grpc protos.
grep "(ST1019)\|\(other import of\)" "${SC_OUT}" | not grep -Fv 'XXXXX PleaseIgnoreUnused
channelz/grpc_channelz_v1"
go-control-plane/envoy
grpclb/grpc_lb_v1"
health/grpc_health_v1"
interop/grpc_testing"
orca/v3"
proto/grpc_gcp"
proto/grpc_lookup_v1"
reflection/grpc_reflection_v1"
reflection/grpc_reflection_v1alpha"
XXXXX PleaseIgnoreUnused'
# Error for any package comments not in generated code.
grep "(ST1000)" "${SC_OUT}" | not grep -v "\.pb\.go:"
# Only ignore the following deprecated types/fields/functions and exclude
# generated code.
grep "(SA1019)" "${SC_OUT}" | not grep -Fv 'XXXXX PleaseIgnoreUnused
XXXXX Protobuf related deprecation errors:
"github.com/golang/protobuf
.pb.go:
grpc_testing_not_regenerate
: ptypes.
proto.RegisterType
XXXXX gRPC internal usage deprecation errors:
"google.golang.org/grpc
: grpc.
: v1alpha.
: v1alphareflectionpb.
BalancerAttributes is deprecated:
CredsBundle is deprecated:
Metadata is deprecated: use Attributes instead.
NewSubConn is deprecated:
OverrideServerName is deprecated:
RemoveSubConn is deprecated:
SecurityVersion is deprecated:
Target is deprecated: Use the Target field in the BuildOptions instead.
UpdateAddresses is deprecated:
UpdateSubConnState is deprecated:
balancer.ErrTransientFailure is deprecated:
grpc/reflection/v1alpha/reflection.proto
XXXXX xDS deprecated fields we support
.ExactMatch
.PrefixMatch
.SafeRegexMatch
.SuffixMatch
GetContainsMatch
GetExactMatch
GetMatchSubjectAltNames
GetPrefixMatch
GetSafeRegexMatch
GetSuffixMatch
GetTlsCertificateCertificateProviderInstance
GetValidationContextCertificateProviderInstance
XXXXX PleaseIgnoreUnused'
echo SUCCESS

18
vendor/modules.txt vendored
View File

@ -406,21 +406,22 @@ golang.org/x/text/width
# golang.org/x/time v0.6.0 # golang.org/x/time v0.6.0
## explicit; go 1.18 ## explicit; go 1.18
golang.org/x/time/rate golang.org/x/time/rate
# google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 # google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117
## explicit; go 1.19 ## explicit; go 1.20
google.golang.org/genproto/googleapis/api/httpbody google.golang.org/genproto/googleapis/api/httpbody
# google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 # google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117
## explicit; go 1.19 ## explicit; go 1.20
google.golang.org/genproto/googleapis/rpc/errdetails google.golang.org/genproto/googleapis/rpc/errdetails
google.golang.org/genproto/googleapis/rpc/status google.golang.org/genproto/googleapis/rpc/status
# google.golang.org/grpc v1.62.0 # google.golang.org/grpc v1.66.2
## explicit; go 1.19 ## explicit; go 1.21
google.golang.org/grpc google.golang.org/grpc
google.golang.org/grpc/attributes google.golang.org/grpc/attributes
google.golang.org/grpc/backoff google.golang.org/grpc/backoff
google.golang.org/grpc/balancer google.golang.org/grpc/balancer
google.golang.org/grpc/balancer/base google.golang.org/grpc/balancer/base
google.golang.org/grpc/balancer/grpclb/state google.golang.org/grpc/balancer/grpclb/state
google.golang.org/grpc/balancer/pickfirst
google.golang.org/grpc/balancer/roundrobin google.golang.org/grpc/balancer/roundrobin
google.golang.org/grpc/binarylog/grpc_binarylog_v1 google.golang.org/grpc/binarylog/grpc_binarylog_v1
google.golang.org/grpc/channelz google.golang.org/grpc/channelz
@ -431,7 +432,9 @@ google.golang.org/grpc/credentials/insecure
google.golang.org/grpc/encoding google.golang.org/grpc/encoding
google.golang.org/grpc/encoding/gzip google.golang.org/grpc/encoding/gzip
google.golang.org/grpc/encoding/proto google.golang.org/grpc/encoding/proto
google.golang.org/grpc/experimental/stats
google.golang.org/grpc/grpclog google.golang.org/grpc/grpclog
google.golang.org/grpc/grpclog/internal
google.golang.org/grpc/health/grpc_health_v1 google.golang.org/grpc/health/grpc_health_v1
google.golang.org/grpc/internal google.golang.org/grpc/internal
google.golang.org/grpc/internal/backoff google.golang.org/grpc/internal/backoff
@ -443,7 +446,6 @@ google.golang.org/grpc/internal/channelz
google.golang.org/grpc/internal/credentials google.golang.org/grpc/internal/credentials
google.golang.org/grpc/internal/envconfig google.golang.org/grpc/internal/envconfig
google.golang.org/grpc/internal/grpclog google.golang.org/grpc/internal/grpclog
google.golang.org/grpc/internal/grpcrand
google.golang.org/grpc/internal/grpcsync google.golang.org/grpc/internal/grpcsync
google.golang.org/grpc/internal/grpcutil google.golang.org/grpc/internal/grpcutil
google.golang.org/grpc/internal/idle google.golang.org/grpc/internal/idle
@ -455,11 +457,13 @@ google.golang.org/grpc/internal/resolver/dns/internal
google.golang.org/grpc/internal/resolver/passthrough google.golang.org/grpc/internal/resolver/passthrough
google.golang.org/grpc/internal/resolver/unix google.golang.org/grpc/internal/resolver/unix
google.golang.org/grpc/internal/serviceconfig google.golang.org/grpc/internal/serviceconfig
google.golang.org/grpc/internal/stats
google.golang.org/grpc/internal/status google.golang.org/grpc/internal/status
google.golang.org/grpc/internal/syscall google.golang.org/grpc/internal/syscall
google.golang.org/grpc/internal/transport google.golang.org/grpc/internal/transport
google.golang.org/grpc/internal/transport/networktype google.golang.org/grpc/internal/transport/networktype
google.golang.org/grpc/keepalive google.golang.org/grpc/keepalive
google.golang.org/grpc/mem
google.golang.org/grpc/metadata google.golang.org/grpc/metadata
google.golang.org/grpc/peer google.golang.org/grpc/peer
google.golang.org/grpc/resolver google.golang.org/grpc/resolver