1
0
mirror of https://github.com/etcd-io/etcd.git synced 2026-02-05 15:46:51 +01:00
Files
etcd/cache/cache.go
Peter Chang ed1d377da9 cache: make cache progress-aware
Signed-off-by: Peter Chang <peter.yaochen.chang@gmail.com>
2025-10-01 20:21:30 +00:00

412 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright 2025 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"bytes"
"context"
"errors"
"fmt"
"sync"
"time"
pb "go.etcd.io/etcd/api/v3/etcdserverpb"
clientv3 "go.etcd.io/etcd/client/v3"
)
var (
// ErrUnsupportedRequest is returned when an option combination isn't yet
// handled by the cache (e.g. WithPrevKV, WithProgressNotify for Watch(),
// WithCountOnly for Get()).
ErrUnsupportedRequest = errors.New("cache: unsupported request parameters")
// ErrKeyRangeInvalid is returned when the requested key or key range is
// invalid (empty or reversed) or lies outside c.prefix.
// NOTE(review): the message text reads "outofrange"; it looks like hyphens
// were lost at some point — confirm against the intended wording.
ErrKeyRangeInvalid = errors.New("cache: invalid or outofrange key range")
)
// Cache buffers a single etcd Watch for a given key prefix and fans its events
// out to local watchers.
type Cache struct {
prefix string // prefix is the key-prefix this shard is responsible for ("" = root).
cfg Config // immutable runtime configuration
watcher clientv3.Watcher // upstream watch client; taken from client.Watcher in New
kv clientv3.KV // upstream KV client; used by get to load the prefix snapshot
demux *demux // demux fans incoming events out to active watchers and manages resync.
store *store // last-observed snapshot
ready *ready // readiness gate: set on the first upstream watch response, reset on watch/apply errors
stop context.CancelFunc // cancels internalCtx; called by Close
waitGroup sync.WaitGroup // tracks goroutines started by the cache; Close waits on it
internalCtx context.Context // cancelled by stop; bounds the cache's background goroutines
}
// New builds a cache shard that watches only the requested prefix and starts
// its background get/watch loop. For the root cache pass "".
//
// Options are applied on top of defaultConfig(). An error is returned, and
// nothing is started, when the resulting HistoryWindowSize is not positive or
// BTreeDegree is below 2.
func New(client *clientv3.Client, prefix string, opts ...Option) (*Cache, error) {
	cfg := defaultConfig()
	for i := range opts {
		opts[i](&cfg)
	}

	// Validate before any context or goroutine is created so a bad
	// configuration leaves nothing behind to clean up.
	switch {
	case cfg.HistoryWindowSize <= 0:
		return nil, fmt.Errorf("invalid HistoryWindowSize %d (must be > 0)", cfg.HistoryWindowSize)
	case cfg.BTreeDegree < 2:
		return nil, fmt.Errorf("invalid BTreeDegree %d (must be >= 2)", cfg.BTreeDegree)
	}

	ctx, cancel := context.WithCancel(context.Background())
	c := &Cache{
		prefix:      prefix,
		cfg:         cfg,
		watcher:     client.Watcher,
		kv:          client.KV,
		store:       newStore(cfg.BTreeDegree, cfg.HistoryWindowSize),
		ready:       newReady(),
		stop:        cancel,
		internalCtx: ctx,
	}
	// The demux shares the cache's context and wait group.
	c.demux = NewDemux(ctx, &c.waitGroup, cfg.HistoryWindowSize, cfg.ResyncInterval)

	c.waitGroup.Add(1)
	go func() {
		defer c.waitGroup.Done()
		c.getWatchLoop()
	}()

	return c, nil
}
// Watch registers a cache-backed watcher for a given key or prefix.
// It returns a WatchChan that streams WatchResponses containing events.
//
// Outcomes:
//   - If the cache does not become ready before ctx is done, the returned
//     channel is already closed and carries no response.
//   - If the options or key range are rejected by validateWatch, the returned
//     channel carries a single Canceled response whose CancelReason is the
//     validation error, and is then closed.
//   - Otherwise a watcher is registered with the demux at the requested start
//     revision and its responses are forwarded until ctx is done, the cache is
//     closed, or the watcher's own channel is closed. In the last case the
//     watcher's cancel response, if any, is forwarded before the channel closes.
func (c *Cache) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
	if c.WaitReady(ctx) != nil {
		closed := make(chan clientv3.WatchResponse)
		close(closed)
		return closed
	}

	op := clientv3.OpWatch(key, opts...)
	pred, err := c.validateWatch(key, op)
	if err != nil {
		// Buffered so the single cancel response can be queued without a reader.
		rejected := make(chan clientv3.WatchResponse, 1)
		rejected <- clientv3.WatchResponse{Canceled: true, CancelReason: err.Error()}
		close(rejected)
		return rejected
	}

	w := newWatcher(c.cfg.PerWatcherBufferSize, pred)
	c.demux.Register(w, op.Rev())

	out := make(chan clientv3.WatchResponse)

	// forward hands one response to the caller. It reports false when the
	// caller's context or the cache's internal context ended first.
	forward := func(resp clientv3.WatchResponse) bool {
		select {
		case <-ctx.Done():
			return false
		case <-c.internalCtx.Done():
			return false
		case out <- resp:
			return true
		}
	}

	c.waitGroup.Add(1)
	go func() {
		// Deferred calls run in reverse: unregister, close out, then Done.
		defer c.waitGroup.Done()
		defer close(out)
		defer c.demux.Unregister(w)

		for {
			select {
			case <-ctx.Done():
				return
			case <-c.internalCtx.Done():
				return
			case resp, open := <-w.respCh:
				if !open {
					// The watcher's channel was closed; pass along the cancel
					// response when one was recorded, then stop either way.
					if w.cancelResp != nil {
						forward(*w.cancelResp)
					}
					return
				}
				if !forward(resp) {
					return
				}
			}
		}
	}()

	return out
}
// Get serves a read for key (or the range selected by opts) from the local
// store instead of the etcd server.
//
// While the store has not yet observed any revision, Get first waits for the
// cache to become ready and returns the wait error if ctx ends first. Requests
// rejected by validateGet return its error unchanged. On success the response
// header carries the revision reported by the store, and Count is the number
// of key-values returned.
func (c *Cache) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
	// Only block on readiness while the store has never been populated.
	if c.store.LatestRev() == 0 {
		if waitErr := c.WaitReady(ctx); waitErr != nil {
			return nil, waitErr
		}
	}

	op := clientv3.OpGet(key, opts...)
	if _, invalid := c.validateGet(key, op); invalid != nil {
		return nil, invalid
	}

	kvs, rev, err := c.store.Get([]byte(key), op.RangeBytes(), op.Rev())
	if err != nil {
		return nil, err
	}

	resp := &clientv3.GetResponse{
		Header: &pb.ResponseHeader{Revision: rev},
		Kvs:    kvs,
		Count:  int64(len(kvs)),
	}
	return resp, nil
}
// Ready reports whether the cache is currently ready, as tracked by c.ready.
// Readiness is set when the upstream watch delivers its first response after
// the snapshot has been loaded, and is reset when the watch or the store
// apply loop reports an error.
func (c *Cache) Ready() bool {
return c.ready.Ready()
}
// WaitReady blocks until the cache is ready or the ctx is cancelled.
// It returns whatever c.ready.WaitReady returns for ctx; callers in this file
// treat a non-nil result as "not ready".
func (c *Cache) WaitReady(ctx context.Context) error {
return c.ready.WaitReady(ctx)
}
// WaitForRevision blocks until the local store has observed revision rev or a
// newer one, polling c.store.LatestRev() every 10ms.
//
// It returns nil once the revision has been reached, ctx.Err() if ctx ends
// first, and the internal context's error if the cache is closed while
// waiting. After Close the store can no longer advance, so continuing to poll
// would only spin until ctx expires.
func (c *Cache) WaitForRevision(ctx context.Context, rev int64) error {
	const pollInterval = 10 * time.Millisecond

	// Fast path: no timer needed when the revision is already available.
	if c.store.LatestRev() >= rev {
		return nil
	}

	// One ticker for the whole wait instead of a fresh timer per iteration.
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if c.store.LatestRev() >= rev {
				return nil
			}
		case <-ctx.Done():
			return ctx.Err()
		case <-c.internalCtx.Done():
			return c.internalCtx.Err()
		}
	}
}
// Close cancels the cache's internal context via c.stop and then blocks until
// every goroutine tracked by c.waitGroup has returned.
func (c *Cache) Close() {
c.stop()
c.waitGroup.Wait()
}
// getWatchLoop repeatedly runs getWatch (snapshot load followed by watch)
// until the cache's internal context is cancelled. After each attempt it
// waits for the configured initial backoff before trying again; a failed
// attempt is reported on standard output.
func (c *Cache) getWatchLoop() {
	ctx := c.internalCtx
	// Read the backoff from the cache's own configuration rather than from a
	// fresh defaultConfig(), so that options passed to New are honored.
	backoff := c.cfg.InitialBackoff

	for {
		if ctx.Err() != nil {
			return
		}
		if err := c.getWatch(); err != nil {
			fmt.Printf("getWatch failed, will retry after %v: %v\n", backoff, err)
		}
		select {
		case <-ctx.Done():
			return
		case <-time.After(backoff):
		}
	}
}
// getWatch performs one full sync cycle: it loads a snapshot of the prefix
// and then watches from the revision immediately after the snapshot's header
// revision. It returns the snapshot error, or whatever watch returns.
func (c *Cache) getWatch() error {
	snapshot, err := c.get(c.internalCtx)
	if err != nil {
		return err
	}
	nextRev := snapshot.Header.Revision + 1
	return c.watch(nextRev)
}
// get reads every key under c.prefix from the upstream KV client and restores
// the local store from the result, tagged with the response's header revision.
// On error the store is left untouched and the error is returned.
func (c *Cache) get(ctx context.Context) (*clientv3.GetResponse, error) {
	snapshot, err := c.kv.Get(ctx, c.prefix, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}

	headerRev := snapshot.Header.Revision
	c.store.Restore(snapshot.Kvs, headerRev)
	return snapshot, nil
}
// watch drives the upstream watch starting at rev. Each iteration registers a
// dedicated store watcher with the demux, starts a goroutine that applies that
// watcher's responses to the store, and then runs watchEvents. It returns only
// when watchEvents yields a non-nil error; on a nil result the store watcher is
// unregistered and a new iteration starts with the same rev.
func (c *Cache) watch(rev int64) error {
// Shared across iterations so the one-time init in watchEvents runs at most
// once per call to watch.
readyOnce := sync.Once{}
for {
// nil predicate — presumably matches every key; confirm in newWatcher.
storeW := newWatcher(c.cfg.PerWatcherBufferSize, nil)
c.demux.Register(storeW, rev)
// Buffered (1) so the goroutine's single send cannot block even if
// watchEvents has already returned and nobody is receiving.
applyErr := make(chan error, 1)
c.waitGroup.Add(1)
go func() {
defer c.waitGroup.Done()
if err := c.applyStorage(storeW); err != nil {
applyErr <- err
}
// Closing also signals a clean exit: receivers then observe a nil error.
close(applyErr)
}()
err := c.watchEvents(rev, applyErr, &readyOnce)
c.demux.Unregister(storeW)
if err != nil {
return err
}
}
}
// applyStorage drains storeW's response channel and applies each response to
// the store. It returns nil when the cache's internal context is cancelled or
// the channel is closed, and the first error returned by store.Apply otherwise.
func (c *Cache) applyStorage(storeW *watcher) error {
	shutdown := c.internalCtx.Done()
	for {
		select {
		case <-shutdown:
			return nil
		case resp, open := <-storeW.respCh:
			if !open {
				return nil
			}
			if err := c.store.Apply(resp); err != nil {
				return err
			}
		}
	}
}
// watchEvents opens the upstream watch on c.prefix starting at rev (prefix
// watch, with progress and created notifications requested) and hands every
// response to the demux until something stops it.
//
// It returns:
//   - c.internalCtx.Err() when the cache's internal context is cancelled;
//   - nil when the upstream watch channel is closed;
//   - the response's error, or the error from demux.Broadcast, after resetting
//     readiness;
//   - whatever is received from applyErr (nil if that channel was closed
//     without an error having been sent), after resetting readiness.
//
// readyOnce guards the one-time demux.Init(rev) and ready.Set() performed when
// the first response arrives from the upstream channel.
func (c *Cache) watchEvents(rev int64, applyErr <-chan error, readyOnce *sync.Once) error {
watchCh := c.watcher.Watch(
c.internalCtx,
c.prefix,
clientv3.WithPrefix(),
clientv3.WithRev(rev),
clientv3.WithProgressNotify(),
clientv3.WithCreatedNotify(),
)
for {
select {
case <-c.internalCtx.Done():
return c.internalCtx.Err()
case resp, ok := <-watchCh:
if !ok {
return nil
}
// Runs for the first response only, and before the error check below, so
// readiness is set even if that first response carries an error (it is
// then reset again immediately).
readyOnce.Do(func() {
c.demux.Init(rev)
c.ready.Set()
})
if err := resp.Err(); err != nil {
c.ready.Reset()
return err
}
err := c.demux.Broadcast(resp)
if err != nil {
c.ready.Reset()
return err
}
// The store-apply goroutine finished: either it sent an error, or it closed
// the channel, in which case err is nil here.
case err := <-applyErr:
c.ready.Reset()
return err
}
}
}
// validateWatch checks that a Watch request can be served by the cache and
// returns the key predicate for its range.
//
// Requests using PrevKV, Fragment, ProgressNotify, CreatedNotify, FilterPut or
// FilterDelete are rejected with an error wrapping ErrUnsupportedRequest. A
// range rejected by validateRange is returned as that error unchanged.
func (c *Cache) validateWatch(key string, op clientv3.Op) (pred KeyPredicate, err error) {
	// Name of the first unsupported option found, in a fixed order.
	var unsupported string
	switch {
	case op.IsPrevKV():
		unsupported = "PrevKV"
	case op.IsFragment():
		unsupported = "Fragment"
	case op.IsProgressNotify():
		unsupported = "ProgressNotify"
	case op.IsCreatedNotify():
		unsupported = "CreatedNotify"
	case op.IsFilterPut():
		unsupported = "FilterPut"
	case op.IsFilterDelete():
		unsupported = "FilterDelete"
	}
	if unsupported != "" {
		return nil, fmt.Errorf("%w: %s not supported", ErrUnsupportedRequest, unsupported)
	}

	// end: nil = single key, {0} = FromKey, else explicit range
	start, end := []byte(key), op.RangeBytes()
	if rangeErr := c.validateRange(start, end); rangeErr != nil {
		return nil, rangeErr
	}
	return KeyPredForRange(start, end), nil
}
// validateGet checks that a Get request can be served by the cache and returns
// the key predicate for its range.
//
// Requests using CountOnly, PrevKV, a sort order, a limit, any of the
// mod/create revision bounds, or a non-serializable read are rejected with an
// error wrapping ErrUnsupportedRequest. The checks run in that order and the
// first match wins. A range rejected by validateRange is returned as that
// error unchanged.
func (c *Cache) validateGet(key string, op clientv3.Op) (KeyPredicate, error) {
	if op.IsCountOnly() {
		return nil, fmt.Errorf("%w: CountOnly not supported", ErrUnsupportedRequest)
	}
	if op.IsPrevKV() {
		return nil, fmt.Errorf("%w: PrevKV not supported", ErrUnsupportedRequest)
	}
	if op.IsSortSet() {
		return nil, fmt.Errorf("%w: SortSet not supported", ErrUnsupportedRequest)
	}
	if limit := op.Limit(); limit != 0 {
		return nil, fmt.Errorf("%w: Limit(%d) not supported", ErrUnsupportedRequest, limit)
	}
	if rev := op.MinModRev(); rev != 0 {
		return nil, fmt.Errorf("%w: MinModRev(%d) not supported", ErrUnsupportedRequest, rev)
	}
	if rev := op.MaxModRev(); rev != 0 {
		return nil, fmt.Errorf("%w: MaxModRev(%d) not supported", ErrUnsupportedRequest, rev)
	}
	if rev := op.MinCreateRev(); rev != 0 {
		return nil, fmt.Errorf("%w: MinCreateRev(%d) not supported", ErrUnsupportedRequest, rev)
	}
	if rev := op.MaxCreateRev(); rev != 0 {
		return nil, fmt.Errorf("%w: MaxCreateRev(%d) not supported", ErrUnsupportedRequest, rev)
	}
	// The cache only serves serializable reads.
	if !op.IsSerializable() {
		return nil, fmt.Errorf("%w: non-serializable request", ErrUnsupportedRequest)
	}

	start, end := []byte(key), op.RangeBytes()
	if rangeErr := c.validateRange(start, end); rangeErr != nil {
		return nil, rangeErr
	}
	return KeyPredForRange(start, end), nil
}
// validateRange reports whether [startKey, endKey) can be served by this
// cache, returning ErrKeyRangeInvalid when it cannot.
//
// endKey selects the request shape:
//   - empty: a single key, which must fall inside the cache's prefix range;
//   - exactly {0}: a FromKey request, accepted only by the root cache;
//   - anything else: an explicit range, which must be non-empty and not
//     reversed (endKey > startKey) and must lie within the prefix range.
//
// A cache with an empty prefix (root) accepts every well-formed request.
func (c *Cache) validateRange(startKey, endKey []byte) error {
	root := c.prefix == ""

	// FromKey is unbounded above, so only the root cache can cover it.
	if len(endKey) == 1 && endKey[0] == 0 {
		if root {
			return nil
		}
		return ErrKeyRangeInvalid
	}

	single := len(endKey) == 0
	if !single && bytes.Compare(endKey, startKey) <= 0 {
		return ErrKeyRangeInvalid
	}
	if root {
		return nil
	}

	lower := []byte(c.prefix)
	upper := []byte(clientv3.GetPrefixRangeEnd(c.prefix))
	if bytes.Compare(startKey, lower) < 0 {
		return ErrKeyRangeInvalid
	}
	if single {
		// A single key must be strictly below the exclusive upper bound.
		if bytes.Compare(startKey, upper) >= 0 {
			return ErrKeyRangeInvalid
		}
		return nil
	}
	// An explicit range's exclusive end may equal the prefix's exclusive end.
	if bytes.Compare(endKey, upper) > 0 {
		return ErrKeyRangeInvalid
	}
	return nil
}