From c1d8a4a8503cb6010d49ea7b621e73f7cbc6c62f Mon Sep 17 00:00:00 2001 From: Kyle Sanderson Date: Wed, 18 Dec 2024 09:15:06 +1300 Subject: [PATCH] feat(cache): implement TTLCache and TimeCache (#1822) * feat(pkg): implement ttlcache and timecache --- pkg/regexcache/regex.go | 40 +++--- pkg/timecache/timecache.go | 74 ++++++++++ pkg/timecache/timecache_test.go | 49 +++++++ pkg/ttlcache/domain.go | 41 ++++++ pkg/ttlcache/expiration.go | 76 ++++++++++ pkg/ttlcache/internal.go | 111 +++++++++++++++ pkg/ttlcache/ttlcache.go | 127 +++++++++++++++++ pkg/ttlcache/ttlcache_test.go | 245 ++++++++++++++++++++++++++++++++ 8 files changed, 742 insertions(+), 21 deletions(-) create mode 100644 pkg/timecache/timecache.go create mode 100644 pkg/timecache/timecache_test.go create mode 100644 pkg/ttlcache/domain.go create mode 100644 pkg/ttlcache/expiration.go create mode 100644 pkg/ttlcache/internal.go create mode 100644 pkg/ttlcache/ttlcache.go create mode 100644 pkg/ttlcache/ttlcache_test.go diff --git a/pkg/regexcache/regex.go b/pkg/regexcache/regex.go index 581d165..3453ae2 100644 --- a/pkg/regexcache/regex.go +++ b/pkg/regexcache/regex.go @@ -7,21 +7,19 @@ import ( "regexp" "time" - "github.com/jellydator/ttlcache/v3" + "github.com/autobrr/autobrr/pkg/ttlcache" ) var cache = ttlcache.New[string, *regexp.Regexp]( - ttlcache.WithTTL[string, *regexp.Regexp](5 * time.Minute), + ttlcache.Options[string, *regexp.Regexp]{}. + SetTimerResolution(5 * time.Minute). + SetDefaultTTL(15 * time.Minute), ) -func init() { - go cache.Start() -} - func MustCompilePOSIX(pattern string) *regexp.Regexp { - item := cache.Get(pattern) - if item != nil { - return item.Value() + item, ok := cache.Get(pattern) + if ok { + return item } reg := regexp.MustCompilePOSIX(pattern) @@ -30,9 +28,9 @@ func MustCompilePOSIX(pattern string) *regexp.Regexp { } func MustCompile(pattern string) *regexp.Regexp { - item := cache.Get(pattern) - if item != nil { - return item.Value() + item, ok := cache.Get(pattern) + if ok { + return item } reg := regexp.MustCompile(pattern) @@ -41,9 +39,9 @@ func MustCompile(pattern string) *regexp.Regexp { } func CompilePOSIX(pattern string) (*regexp.Regexp, error) { - item := cache.Get(pattern) - if item != nil { - return item.Value(), nil + item, ok := cache.Get(pattern) + if ok { + return item, nil } reg, err := regexp.CompilePOSIX(pattern) @@ -56,9 +54,9 @@ func CompilePOSIX(pattern string) (*regexp.Regexp, error) { } func Compile(pattern string) (*regexp.Regexp, error) { - item := cache.Get(pattern) - if item != nil { - return item.Value(), nil + item, ok := cache.Get(pattern) + if ok { + return item, nil } reg, err := regexp.Compile(pattern) @@ -75,9 +73,9 @@ func SubmitOriginal(plain string, reg *regexp.Regexp) { } func FindOriginal(plain string) (*regexp.Regexp, bool) { - item := cache.Get(plain) - if item != nil { - return item.Value(), true + item, ok := cache.Get(plain) + if ok { + return item, true } return nil, false diff --git a/pkg/timecache/timecache.go b/pkg/timecache/timecache.go new file mode 100644 index 0000000..d9b1a8b --- /dev/null +++ b/pkg/timecache/timecache.go @@ -0,0 +1,74 @@ +package timecache + +import ( + "sync" + "time" +) + +type Cache struct { + m sync.RWMutex + t time.Time + o Options +} + +type Options struct { + round time.Duration +} + +func New(o Options) *Cache { + c := Cache{ + o: o, + } + + return &c +} + +func (t *Cache) Now() time.Time { + t.m.RLock() + if !t.t.IsZero() { + defer t.m.RUnlock() + return t.t + } + + t.m.RUnlock() + return t.update() +} + +func (t *Cache) update() time.Time { + t.m.Lock() + defer t.m.Unlock() + if !t.t.IsZero() { + return t.t + } + + var d time.Duration + if t.o.round > time.Nanosecond { + d = t.o.round + } else { + d = time.Second * 1 + } + + t.t = time.Now().Round(d) + + go func(duration time.Duration) { + if t.o.round > time.Nanosecond { + duration = t.o.round / 2 + } + + time.Sleep(duration) + t.reset() + }(d) + + return t.t +} + +func (t *Cache) reset() { + t.m.Lock() + defer t.m.Unlock() + t.t = time.Time{} +} + +func (o Options) Round(d time.Duration) Options { + o.round = d + return o +} diff --git a/pkg/timecache/timecache_test.go b/pkg/timecache/timecache_test.go new file mode 100644 index 0000000..ee7e6af --- /dev/null +++ b/pkg/timecache/timecache_test.go @@ -0,0 +1,49 @@ +package timecache + +import ( + "testing" + "time" +) + +func TestTime(t *testing.T) { + t.Parallel() + tc := (&Cache{}).Now() + if tc.IsZero() { + t.Fatalf("time is zero") + } +} + +func TestRounding(t *testing.T) { + t.Parallel() + ti := New(Options{}.Round(time.Minute * 5)).Now() + + if ti.Minute()%5 != 0 { + t.Fatalf("time is not a 5 multiple") + } +} + +func TestResolution(t *testing.T) { + t.Parallel() + const magicNumber = 3 + const rounds = 700 + ti := New(Options{}.Round(time.Millisecond * magicNumber)) + + unique := 0 + old := ti.Now().UnixMilli() + for i := 0; i < rounds; i++ { + new := ti.Now().UnixMilli() + if new > old { + unique++ + old = new + } + + if div := new % magicNumber; div != 0 { + t.Fatalf("not a multiple of %d: %d", magicNumber, div) + } + time.Sleep(time.Millisecond * 1) + } + + if unique < rounds/magicNumber-1 { + t.Fatalf("not enough resolution rounds %d", unique) + } +} diff --git a/pkg/ttlcache/domain.go b/pkg/ttlcache/domain.go new file mode 100644 index 0000000..b324a3f --- /dev/null +++ b/pkg/ttlcache/domain.go @@ -0,0 +1,41 @@ +package ttlcache + +import ( + "sync" + "time" + + "github.com/autobrr/autobrr/pkg/timecache" +) + +const NoTTL time.Duration = 0 +const DefaultTTL time.Duration = time.Nanosecond * 1 + +type Cache[K comparable, V any] struct { + tc timecache.Cache + l sync.RWMutex + o Options[K, V] + ch chan time.Time + m map[K]Item[V] +} + +type Item[V any] struct { + t time.Time + d time.Duration + v V +} + +type Options[K comparable, V any] struct { + defaultTTL time.Duration + defaultResolution time.Duration + deallocationFunc DeallocationFunc[K, V] + noUpdateTime bool +} + +type DeallocationReason int + +const ( + ReasonTimedOut = DeallocationReason(iota) + ReasonDeleted = DeallocationReason(iota) +) + +type DeallocationFunc[K comparable, V any] func(key K, value V, reason DeallocationReason) diff --git a/pkg/ttlcache/expiration.go b/pkg/ttlcache/expiration.go new file mode 100644 index 0000000..e1864c0 --- /dev/null +++ b/pkg/ttlcache/expiration.go @@ -0,0 +1,76 @@ +package ttlcache + +import ( + "time" +) + +func (c *Cache[K, V]) startExpirations() { + timer := time.NewTimer(1 * time.Second) + stopTimer(timer) // wasteful, but makes the loop cleaner because this is initialized. + defer stopTimer(timer) + + var timeSleep time.Time + for { + select { + case t, ok := <-c.ch: + if !ok { + return + } else if t.IsZero() { + continue + } + + if timeSleep.IsZero() || timeSleep.After(t) { + timeSleep = t + restartTimer(timer, timeSleep.Sub(c.tc.Now())) + } + + case <-timer.C: + stopTimer(timer) + c.expire() + timeSleep = time.Time{} + } + } +} + +func restartTimer(t *time.Timer, d time.Duration) { + stopTimer(t) + t.Reset(d) +} + +func stopTimer(t *time.Timer) { + t.Stop() + + // go < 1.23 returns stale values on expired timers. + if len(t.C) != 0 { + <-t.C + } +} + +func (c *Cache[K, V]) expire() { + t := c.tc.Now() + var soon time.Time + + c.l.Lock() + defer c.l.Unlock() + for k, v := range c.m { + if v.t.IsZero() { + continue + } else if v.t.After(t) { + if soon.IsZero() || soon.After(v.t) { + soon = v.t + } + continue + } + + c.deleteUnsafe(k, v, ReasonTimedOut) + } + + if !soon.IsZero() { // wake-up feedback loop + go func(s time.Time) { // we need to release the lock, if the input pipeline has exceeded the wakeup budget. + defer func() { + _ = recover() // if the channel is closed, this doesn't matter on shutdown because this is expected. + }() + c.ch <- s + }(soon) + } +} diff --git a/pkg/ttlcache/internal.go b/pkg/ttlcache/internal.go new file mode 100644 index 0000000..8b6dbd9 --- /dev/null +++ b/pkg/ttlcache/internal.go @@ -0,0 +1,111 @@ +package ttlcache + +import "time" + +func (c *Cache[K, V]) get(key K) (Item[V], bool) { + c.l.RLock() + defer c.l.RUnlock() + return c._g(key) +} + +func (c *Cache[K, V]) _g(key K) (Item[V], bool) { + v, ok := c.m[key] + if !ok { + return v, ok + } + + return v, ok +} + +func (c *Cache[K, V]) set(key K, it Item[V]) Item[V] { + c.l.Lock() + defer c.l.Unlock() + return c._s(key, it) +} + +func (c *Cache[K, V]) _s(key K, it Item[V]) Item[V] { + it.d, it.t = c.getDuration(it.d) + c.m[key] = it + c.ch <- it.t + return it +} + +func (c *Cache[K, V]) getOrSet(key K, it Item[V]) (Item[V], bool) { + c.l.Lock() + defer c.l.Unlock() + return c._gos(key, it) +} + +func (c *Cache[K, V]) _gos(key K, it Item[V]) (Item[V], bool) { + if g, ok := c._g(key); ok { + return g, ok + } + + return c._s(key, it), true +} + +func (c *Cache[K, V]) delete(key K, reason DeallocationReason) { + var v Item[V] + c.l.Lock() + defer c.l.Unlock() + + if c.o.deallocationFunc != nil { + var ok bool + v, ok = c.m[key] + if !ok { + return + } + } + + c.deleteUnsafe(key, v, reason) +} + +func (c *Cache[K, V]) deleteUnsafe(key K, v Item[V], reason DeallocationReason) { + delete(c.m, key) + + if c.o.deallocationFunc != nil { + c.o.deallocationFunc(key, v.v, reason) + } +} + +func (c *Cache[K, V]) getkeys() []K { + c.l.RLock() + defer c.l.RUnlock() + + keys := make([]K, len(c.m)) + for k := range c.m { + keys = append(keys, k) + } + + return keys +} + +func (c *Cache[K, V]) close() { + c.l.Lock() + defer c.l.Unlock() + close(c.ch) +} + +func (c *Cache[K, V]) getDuration(d time.Duration) (time.Duration, time.Time) { + switch d { + case NoTTL: + case DefaultTTL: + return c.o.defaultTTL, c.tc.Now().Add(c.o.defaultTTL) + default: + return d, c.tc.Now().Add(d) + } + + return NoTTL, time.Time{} +} + +func (i *Item[V]) getDuration() time.Duration { + return i.d +} + +func (i *Item[V]) getTime() time.Time { + return i.t +} + +func (i *Item[V]) getValue() V { + return i.v +} diff --git a/pkg/ttlcache/ttlcache.go b/pkg/ttlcache/ttlcache.go new file mode 100644 index 0000000..4f7e095 --- /dev/null +++ b/pkg/ttlcache/ttlcache.go @@ -0,0 +1,127 @@ +package ttlcache + +import ( + "time" + + "github.com/autobrr/autobrr/pkg/timecache" +) + +func New[K comparable, V any](options Options[K, V]) *Cache[K, V] { + c := Cache[K, V]{ + o: options, + ch: make(chan time.Time, 1000), + m: make(map[K]Item[V]), + } + + if options.defaultTTL != NoTTL && options.defaultResolution == 0 { + c.tc = *timecache.New(timecache.Options{}.Round(options.defaultTTL / 2)) + } else if options.defaultResolution != 0 { + c.tc = *timecache.New(timecache.Options{}.Round(options.defaultResolution)) + } + + go c.startExpirations() + return &c +} + +func (c *Cache[K, V]) Get(key K) (V, bool) { + it, ok := c.GetItem(key) + if !ok { + return *new(V), ok + } + + return it.GetValue(), ok +} + +func (c *Cache[K, V]) GetItem(key K) (Item[V], bool) { + it, ok := c.get(key) + if !ok { + return it, ok + } + + if !c.o.noUpdateTime && !it.t.IsZero() { + if _, t := c.getDuration(it.d); t.After(it.t) { + c.set(key, it) + } + } + + return it, ok +} + +func (c *Cache[K, V]) GetOrSet(key K, value V, duration time.Duration) (V, bool) { + it, ok := c.GetOrSetItem(key, value, duration) + if !ok { + return *new(V), ok + } + + return it.GetValue(), ok +} + +func (c *Cache[K, V]) fixupDuration(duration time.Duration) time.Duration { + if c.o.defaultTTL == NoTTL && duration == DefaultTTL { + return NoTTL + } + + return duration +} + +func (c *Cache[K, V]) GetOrSetItem(key K, value V, duration time.Duration) (Item[V], bool) { + it, ok := c.getOrSet(key, Item[V]{v: value, d: c.fixupDuration(duration)}) + if !ok { + return Item[V]{}, ok + } + + return it, ok +} + +func (c *Cache[K, V]) Set(key K, value V, duration time.Duration) bool { + c.SetItem(key, value, duration) + return true +} + +func (c *Cache[K, V]) SetItem(key K, value V, duration time.Duration) Item[V] { + return c.set(key, Item[V]{v: value, d: c.fixupDuration(duration)}) +} + +func (c *Cache[K, V]) Delete(key K) { + c.delete(key, ReasonDeleted) +} + +func (c *Cache[K, V]) GetKeys() []K { + return c.getkeys() +} + +func (c *Cache[K, V]) Close() { + c.close() +} + +func (i *Item[V]) GetDuration() time.Duration { + return i.getDuration() +} + +func (i *Item[V]) GetTime() time.Time { + return i.getTime() +} + +func (i *Item[V]) GetValue() V { + return i.getValue() +} + +func (o Options[K, V]) SetTimerResolution(d time.Duration) Options[K, V] { + o.defaultResolution = d + return o +} + +func (o Options[K, V]) SetDefaultTTL(d time.Duration) Options[K, V] { + o.defaultTTL = d + return o +} + +func (o Options[K, V]) SetDeallocationFunc(f DeallocationFunc[K, V]) Options[K, V] { + o.deallocationFunc = f + return o +} + +func (o Options[K, V]) DisableUpdateTime(val bool) Options[K, V] { + o.noUpdateTime = val + return o +} diff --git a/pkg/ttlcache/ttlcache_test.go b/pkg/ttlcache/ttlcache_test.go new file mode 100644 index 0000000..0ea5188 --- /dev/null +++ b/pkg/ttlcache/ttlcache_test.go @@ -0,0 +1,245 @@ +package ttlcache + +import ( + "testing" + "time" +) + +func TestGet(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(1 * time.Second)) + defer c.Close() + + for i := 0; i < 10; i++ { + c.Set(i, true, DefaultTTL) + } + + for i := 0; i < 10; i++ { + val, ok := c.Get(i) + if !ok { + t.Fatalf("missing key: %d", i) + } else if !val { + t.Fatalf("bad value on key: %d", i) + } + } +} + +func TestExpirations(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(200 * time.Millisecond)) + defer c.Close() + for i := 0; i < 10; i++ { + c.Set(i, true, DefaultTTL) + } + + time.Sleep(1 * time.Second) + + for i := 0; i < 10; i++ { + if _, ok := c.Get(i); ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestSwaps(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(200 * time.Millisecond)) + defer c.Close() + for i := 0; i < 10; i++ { + c.Set(i, true, DefaultTTL) + } + + time.Sleep(1 * time.Second) + for i := 0; i < 10; i++ { + if _, ok := c.Get(i); ok { + t.Fatalf("found key: %d", i) + } + } + + for i := 10; i < 20; i++ { + c.Set(i, true, DefaultTTL) + if _, ok := c.Get(i); !ok { + t.Fatalf("missing key: %d", i) + } + } +} + +func TestRetimer(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(200 * time.Millisecond)) + defer c.Close() + for i := 1; i < 10; i++ { + c.Set(i, true, time.Duration(10-i)*100*time.Millisecond) + } + + time.Sleep(2 * time.Second) + for i := 1; i < 10; i++ { + if _, ok := c.Get(i); ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestSchedule(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(1 * time.Second)) + defer c.Close() + for i := 1; i < 10; i++ { + c.Set(i, true, time.Duration(i)*100*time.Millisecond) + } + + time.Sleep(3 * time.Second) + for i := 1; i < 10; i++ { + if _, ok := c.Get(i); ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestInterlace(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(100 * time.Millisecond)) + defer c.Close() + swap := false + for i := 0; i < 10; i++ { + swap = !swap + ttl := DefaultTTL + if swap { + ttl = NoTTL + } + c.Set(i, true, ttl) + } + + time.Sleep(1 * time.Second) + swap = false + for i := 0; i < 10; i++ { + swap = !swap + if !swap { + continue + } + + if _, ok := c.Get(i); !ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestReschedule(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(100 * time.Millisecond)) + defer c.Close() + for i := 1; i < 10; i++ { + c.Set(i, true, NoTTL) + c.Set(i, true, DefaultTTL) + } + + time.Sleep(1 * time.Second) + for i := 1; i < 10; i++ { + if _, ok := c.Get(i); ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestRescheduleNoTTL(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(100 * time.Millisecond)) + defer c.Close() + for i := 1; i < 10; i++ { + c.Set(i, true, DefaultTTL) + c.Set(i, true, NoTTL) + } + + time.Sleep(1 * time.Second) + for i := 1; i < 10; i++ { + if _, ok := c.Get(i); !ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestDelete(t *testing.T) { + t.Parallel() + c := New[int, bool](Options[int, bool]{}.SetDefaultTTL(100 * time.Millisecond)) + defer c.Close() + for i := 1; i < 10; i++ { + c.Set(i, true, NoTTL) + c.Delete(i) + } + + for i := 1; i < 10; i++ { + if _, ok := c.Get(i); ok { + t.Fatalf("found key: %d", i) + } + } +} + +func TestDeallocationTimeout(t *testing.T) { + t.Parallel() + hit := false + o := Options[int, bool]{}. + SetDefaultTTL(time.Millisecond * 100). + SetDeallocationFunc(func(key int, value bool, reason DeallocationReason) { hit = reason == ReasonTimedOut }) + + c := New[int, bool](o) + defer c.Close() + + for i := 0; i < 1; i++ { + c.Set(i, true, DefaultTTL) + } + + time.Sleep(3 * time.Second) + if !hit { + t.Fatalf("Deallocation not hit.") + } +} + +func TestDeallocationDeleted(t *testing.T) { + t.Parallel() + hit := false + o := Options[int, bool]{}. + SetDefaultTTL(time.Millisecond * 100). + SetDeallocationFunc(func(key int, value bool, reason DeallocationReason) { hit = reason == ReasonDeleted }) + + c := New[int, bool](o) + defer c.Close() + + for i := 0; i < 1; i++ { + c.Set(i, true, DefaultTTL) + c.Delete(i) + } + + if !hit { + t.Fatalf("Deallocation not hit.") + } +} + +func TestTimerReset(t *testing.T) { + t.Parallel() + ch := make(chan struct{}) + defer close(ch) + + c := New[int, bool](Options[int, bool]{}. + SetDefaultTTL(time.Millisecond * 100). + SetDeallocationFunc(func(key int, value bool, reason DeallocationReason) { ch <- struct{}{} })) + + defer c.Close() + + const base = 0 + const rounds = 1 + for i := base; i < rounds; i++ { + c.Set(i, true, DefaultTTL) + } + + for i := base; i < rounds; i++ { + <-ch + } + + for i := 0; i < 1; i++ { + c.Set(i, true, DefaultTTL) + } + + for i := base; i < rounds; i++ { + <-ch + } +}