# SPDX-FileCopyrightText: 2026 Defensive Lab Agency
# SPDX-FileContributor: u039b <git@0x39b.fr>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import threading
import time
from collections import OrderedDict
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar

from mongoose.core import SingletonMeta

# Generic key/value type parameters for Cache[K, V].
K = TypeVar("K")
V = TypeVar("V")

class Cache(Generic[K, V], metaclass=SingletonMeta):
    """
    Sharded LRU cache with optional TTL implemented as a singleton per class.

    The cache partitions keys into a fixed number of shards. Each shard
    maintains its own OrderedDict and lock so operations for different
    shards can proceed concurrently without contending on a single global
    lock. Each shard enforces its own capacity (approximately
    max_size / num_shards), so the global max_size is a soft bound but
    should be close when keys distribute evenly.

    This design provides better concurrency for high-contention workloads
    while preserving LRU semantics within each shard.
    """

    def __init__(self, max_size: int = 1024, ttl_seconds: Optional[float] = None, num_shards: int = 1) -> None:
        """
        Initialize the sharded cache.

        Args:
            max_size: Global maximum number of entries (positive int).
            ttl_seconds: Optional TTL in seconds for entries. If None, entries never expire.
            num_shards: Number of independent shards to partition the keyspace.

        Raises:
            ValueError: If any argument is out of range or of the wrong type.
        """
        # The singleton metaclass hands back the same instance on every call;
        # bail out so later constructor arguments cannot clobber live state.
        if hasattr(self, "_initialized") and self._initialized:
            return

        if not isinstance(max_size, int) or max_size <= 0:
            raise ValueError("max_size must be a positive integer")
        if ttl_seconds is not None and (not isinstance(ttl_seconds, (int, float)) or ttl_seconds <= 0):
            raise ValueError("ttl_seconds must be a positive number or None")
        if not isinstance(num_shards, int) or num_shards <= 0:
            raise ValueError("num_shards must be a positive integer")

        self.max_size: int = max_size
        self._ttl_seconds: Optional[float] = float(ttl_seconds) if ttl_seconds is not None else None

        # Cap the shard count at max_size so every shard gets at least one
        # slot; a zero-capacity shard would evict every entry immediately.
        self._num_shards: int = min(num_shards, max_size)

        # Distribute capacity as evenly as possible: the first `remainder`
        # shards receive one extra slot.
        base = max_size // self._num_shards
        remainder = max_size % self._num_shards
        self._shard_max_sizes = [base + (1 if i < remainder else 0) for i in range(self._num_shards)]

        # Per-shard stores and locks; each store maps key -> (value, expires_at).
        self._shards: List[Dict[str, Any]] = [
            {"lock": threading.RLock(), "store": OrderedDict()} for _ in range(self._num_shards)
        ]

        # Observability counters, guarded by a dedicated lock so stats
        # updates never contend with shard operations.
        self._stats_lock = threading.Lock()
        self._hits: int = 0
        self._misses: int = 0
        self._evictions: int = 0

        self._initialized = True

    def _now(self) -> float:
        """Return the current monotonic time (immune to wall-clock jumps)."""
        return time.monotonic()

    def _expires_at(self) -> Optional[float]:
        """Return the expiry timestamp for a new entry, or None if no TTL."""
        if self._ttl_seconds is None:
            return None
        return self._now() + self._ttl_seconds

    def _is_expired(self, expires_at: Optional[float]) -> bool:
        """Return True if the given expires_at timestamp is in the past."""
        return expires_at is not None and expires_at <= self._now()

    def _record_hit(self) -> None:
        """Increment the hit counter under the stats lock."""
        with self._stats_lock:
            self._hits += 1

    def _record_miss(self) -> None:
        """Increment the miss counter under the stats lock."""
        with self._stats_lock:
            self._misses += 1

    def _record_eviction(self) -> None:
        """Increment the eviction counter under the stats lock."""
        with self._stats_lock:
            self._evictions += 1

    def get_stats(self) -> Dict[str, int]:
        """Return a consistent snapshot of hits, misses and evictions."""
        with self._stats_lock:
            return {"hits": self._hits, "misses": self._misses, "evictions": self._evictions}

    def reset_stats(self) -> None:
        """Zero all statistics."""
        with self._stats_lock:
            self._hits = 0
            self._misses = 0
            self._evictions = 0

    def _shard_for_key(self, key: K) -> int:
        """Return the shard index for a key (non-negative hash modulo shard count)."""
        return (hash(key) & 0x7FFFFFFF) % self._num_shards

    def _purge_expired_in_shard(self, shard_idx: int) -> None:
        """Drop expired entries from one shard. Caller must hold that shard's lock."""
        if self._ttl_seconds is None:
            return
        now = self._now()
        store: OrderedDict = self._shards[shard_idx]["store"]
        # Collect first, then delete: mutating an OrderedDict while
        # iterating it is unsafe.
        expired = [k for k, (_v, expires_at) in store.items() if expires_at is not None and expires_at <= now]
        for k in expired:
            store.pop(k, None)

    def set(self, key: K, value: V) -> None:
        """
        Insert or update a key in its shard and enforce shard-level capacity.

        The written entry becomes the most-recently-used item of its shard;
        least-recently-used entries are evicted while the shard is over
        capacity.
        """
        shard_idx = self._shard_for_key(key)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            store: OrderedDict = shard["store"]
            # Drop stale entries so they do not occupy capacity.
            self._purge_expired_in_shard(shard_idx)

            if key in store:
                store.move_to_end(key)
            # Writing (re)starts the TTL for this key.
            store[key] = (value, self._expires_at())

            shard_max = self._shard_max_sizes[shard_idx]
            while len(store) > shard_max:
                store.popitem(last=False)  # pop from the LRU end
                self._record_eviction()

    def get(self, key: K) -> Optional[V]:
        """
        Retrieve a key from its shard and mark it as recently used.

        Returns:
            The cached value, or None if the key is missing or expired.
        """
        shard_idx = self._shard_for_key(key)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            store: OrderedDict = shard["store"]
            self._purge_expired_in_shard(shard_idx)

            try:
                value, expires_at = store.pop(key)
            except KeyError:
                self._record_miss()
                return None

            if self._is_expired(expires_at):
                # Defensive: normally dead after the purge above, but keeps
                # an expired value from ever being returned.
                self._record_miss()
                return None

            # Re-insert so the key becomes MRU within its shard.
            store[key] = (value, expires_at)
            self._record_hit()
            return value

    def __len__(self) -> int:
        """Return the total number of non-expired items across all shards."""
        total = 0
        # Visit shards in a fixed order, holding only one lock at a time,
        # so __len__ can never deadlock against other shard operations.
        for idx, shard in enumerate(self._shards):
            with shard["lock"]:
                self._purge_expired_in_shard(idx)
                total += len(shard["store"])
        return total

    def clear(self) -> None:
        """Remove every entry from every shard (statistics are left untouched)."""
        for shard in self._shards:
            with shard["lock"]:
                shard["store"].clear()

    def __contains__(self, key: K) -> bool:
        """Return True if key exists and is not expired (does not touch LRU order)."""
        shard_idx = self._shard_for_key(key)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            self._purge_expired_in_shard(shard_idx)
            return key in shard["store"]

    def items(self) -> Iterator[tuple[K, V]]:
        """
        Yield (key, value) pairs across shards (LRU-to-MRU within each shard).

        A snapshot is taken per shard under its lock, so no lock is held
        while the caller consumes the generator.
        """
        snapshot: List[tuple[K, V]] = []
        for idx, shard in enumerate(self._shards):
            with shard["lock"]:
                self._purge_expired_in_shard(idx)
                snapshot.extend((k, v) for k, (v, _exp) in shard["store"].items())
        yield from snapshot

class SeverityCache(Cache[str, int]):
    """
    Specialization of Cache for community_id -> severity mappings.

    Provides convenience methods `set_severity` and `get_severity` that
    validate input types and keep the simple, explicit API used by the
    rest of the codebase.
    """

    def set_severity(self, community_id: str, severity: int) -> None:
        """
        Store severity for a community_id, keeping the highest value if the
        key already exists.

        The comparison and update run under the shard lock so the
        read-modify-write is atomic and concurrent writers cannot lose
        the max-merge.

        Args:
            community_id: String identifier for the community.
            severity: Integer severity value.

        Raises:
            TypeError: If arguments are of incorrect types.
        """
        if not isinstance(community_id, str):
            raise TypeError("community_id must be a str")
        if not isinstance(severity, int):
            raise TypeError("severity must be an int")

        shard_idx = self._shard_for_key(community_id)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            store: OrderedDict = shard["store"]
            # Drop expired entries first so a stale value cannot win the merge.
            self._purge_expired_in_shard(shard_idx)

            existing = store.get(community_id)
            if existing is None:
                chosen = severity
            else:
                existing_val, existing_expires = existing
                # Defensive: the purge above already removed expired entries,
                # so this branch normally sees only live values.
                chosen = severity if self._is_expired(existing_expires) else max(existing_val, severity)
                store.move_to_end(community_id)  # mark as MRU

            # Writing refreshes the TTL for this key.
            store[community_id] = (chosen, self._expires_at())

            # Enforce shard capacity with the same policy as Cache.set.
            shard_max = self._shard_max_sizes[shard_idx]
            while len(store) > shard_max:
                store.popitem(last=False)
                self._record_eviction()

    def get_severity(self, community_id: str) -> Optional[int]:
        """
        Retrieve the cached severity for a community_id or None if missing.

        Args:
            community_id: String identifier for the community.

        Returns:
            The severity integer if present, otherwise None.

        Raises:
            TypeError: If `community_id` is not a str.
        """
        if not isinstance(community_id, str):
            raise TypeError("community_id must be a str")

        return self.get(community_id)

def reset_singletons() -> None:
    """
    Clear the internal singleton instance registry.

    This is primarily intended for use in unit tests so different tests
    can instantiate singletons with different constructor parameters and
    get fresh instances. Use with care in production code as clearing
    singletons while other threads hold references can be unsafe.
    """
    # Hold the metaclass lock so a concurrent first instantiation cannot
    # race with the registry clear.
    with SingletonMeta._lock:
        SingletonMeta._instances.clear()