import threading
import time
from collections import OrderedDict
from typing import Optional, Generic, TypeVar, Iterator, Dict, Any, List

from mongoose.core import SingletonMeta

K = TypeVar("K")
V = TypeVar("V")
class Cache(Generic[K, V], metaclass=SingletonMeta):
    """
    Singleton, sharded LRU cache with optional time-to-live expiry.

    Keys are hashed into a fixed set of shards; each shard owns its own
    ``OrderedDict`` and ``RLock``, so operations touching different shards
    never contend on a common lock. Capacity is enforced per shard
    (roughly ``max_size / num_shards``), which makes the global limit a
    soft bound that is tight when keys hash evenly.

    LRU ordering is maintained independently inside each shard.
    """

    def __init__(self, max_size: int = 1024, ttl_seconds: Optional[float] = None, num_shards: int = 1) -> None:
        """
        Set up shards, capacities, and statistics counters.

        Args:
            max_size: Global capacity across all shards (positive int).
            ttl_seconds: Per-entry lifetime in seconds, or None to disable
                time-based expiry.
            num_shards: How many independent shards to split the keyspace into.

        Raises:
            ValueError: If any argument is out of its valid range.
        """
        # The singleton metaclass may hand back an existing instance; in
        # that case __init__ runs again and must be a no-op.
        if getattr(self, "_initialized", False):
            return

        if not isinstance(max_size, int) or max_size <= 0:
            raise ValueError("max_size must be a positive integer")
        if ttl_seconds is not None and (not isinstance(ttl_seconds, (int, float)) or ttl_seconds <= 0):
            raise ValueError("ttl_seconds must be a positive number or None")
        if not isinstance(num_shards, int) or num_shards <= 0:
            raise ValueError("num_shards must be a positive integer")

        self.max_size: int = max_size
        self._ttl_seconds: Optional[float] = None if ttl_seconds is None else float(ttl_seconds)

        # Never create more shards than total slots: a zero-capacity shard
        # would evict everything placed into it immediately.
        self._num_shards: int = min(num_shards, max_size)

        # Split max_size as evenly as possible; the first `extra` shards
        # each absorb one leftover slot.
        quotient, extra = divmod(max_size, self._num_shards)
        self._shard_max_sizes = [
            quotient + 1 if i < extra else quotient for i in range(self._num_shards)
        ]

        # One {'lock', 'store'} pair per shard.
        self._shards: List[Dict[str, Any]] = [
            {"lock": threading.RLock(), "store": OrderedDict()}
            for _ in range(self._num_shards)
        ]

        # Hit/miss/eviction counters, guarded by a dedicated lock.
        self._stats_lock = threading.Lock()
        self._hits: int = 0
        self._misses: int = 0
        self._evictions: int = 0

        self._initialized = True

    def _now(self) -> float:
        """Current monotonic clock reading (immune to wall-clock jumps)."""
        return time.monotonic()

    def _expires_at(self) -> Optional[float]:
        """Expiry timestamp for an entry created now, or None without TTL."""
        return None if self._ttl_seconds is None else self._now() + self._ttl_seconds

    def _is_expired(self, expires_at: Optional[float]) -> bool:
        """True when the given deadline exists and has already passed."""
        if expires_at is None:
            return False
        return expires_at <= self._now()

    def _record_hit(self) -> None:
        # Counter update under the stats lock.
        with self._stats_lock:
            self._hits += 1

    def _record_miss(self) -> None:
        # Counter update under the stats lock.
        with self._stats_lock:
            self._misses += 1

    def _record_eviction(self) -> None:
        # Counter update under the stats lock.
        with self._stats_lock:
            self._evictions += 1

    def get_stats(self) -> Dict[str, int]:
        """Return a consistent snapshot of hits, misses, and evictions."""
        with self._stats_lock:
            snapshot = {
                "hits": self._hits,
                "misses": self._misses,
                "evictions": self._evictions,
            }
        return snapshot

    def reset_stats(self) -> None:
        """Reset every statistics counter to zero."""
        with self._stats_lock:
            self._hits = self._misses = self._evictions = 0

    def _shard_for_key(self, key: K) -> int:
        """Map a key to its shard index via its (non-negative) hash."""
        return (hash(key) & 0x7FFFFFFF) % self._num_shards

    def _purge_expired_in_shard(self, shard_idx: int) -> None:
        """Drop expired entries from one shard. Caller must hold its lock."""
        if self._ttl_seconds is None:
            return
        cutoff = self._now()
        entries: OrderedDict = self._shards[shard_idx]["store"]
        # Collect first, then remove — mutating while iterating is unsafe.
        stale = [k for k, (_val, exp) in entries.items() if exp is not None and exp <= cutoff]
        for k in stale:
            entries.pop(k, None)

    def set(self, key: K, value: V) -> None:
        """
        Store `key` -> `value` in its shard as most-recently-used, then
        evict LRU entries until the shard is back within capacity.
        """
        idx = self._shard_for_key(key)
        shard = self._shards[idx]
        with shard["lock"]:
            entries: OrderedDict = shard["store"]
            self._purge_expired_in_shard(idx)

            # An update must also refresh recency; plain assignment to an
            # existing key does not reorder an OrderedDict.
            if key in entries:
                entries.move_to_end(key)
            entries[key] = (value, self._expires_at())

            limit = self._shard_max_sizes[idx]
            while len(entries) > limit:
                entries.popitem(last=False)  # drop least-recently-used
                self._record_eviction()

    def get(self, key: K) -> Optional[V]:
        """
        Look up `key`, refreshing its recency on a hit.

        Returns:
            The cached value, or None when absent or expired.
        """
        idx = self._shard_for_key(key)
        shard = self._shards[idx]
        with shard["lock"]:
            entries: OrderedDict = shard["store"]
            self._purge_expired_in_shard(idx)

            entry = entries.pop(key, None)
            if entry is None:
                self._record_miss()
                return None

            value, expires_at = entry
            if self._is_expired(expires_at):
                # Already popped, so the stale entry is gone for good.
                self._record_miss()
                return None

            # Re-insertion places the key at the MRU end of the shard.
            entries[key] = (value, expires_at)
            self._record_hit()
            return value

    def __len__(self) -> int:
        """Total live (non-expired) entries across every shard."""
        # Shards are visited one at a time; only one lock is held at once,
        # so there is no lock-ordering concern.
        count = 0
        for idx, shard in enumerate(self._shards):
            with shard["lock"]:
                self._purge_expired_in_shard(idx)
                count += len(shard["store"])
        return count

    def clear(self) -> None:
        """Remove every entry from every shard."""
        for shard in self._shards:
            with shard["lock"]:
                shard["store"].clear()

    def __contains__(self, key: K) -> bool:
        """True if `key` is present and not expired."""
        idx = self._shard_for_key(key)
        shard = self._shards[idx]
        with shard["lock"]:
            self._purge_expired_in_shard(idx)
            return key in shard["store"]

    def items(self) -> Iterator[tuple[K, V]]:
        """
        Yield (key, value) pairs shard by shard (LRU-first within a shard).

        Each shard is snapshotted under its lock, so iteration itself runs
        without holding any lock.
        """
        collected: List[tuple] = []
        for idx, shard in enumerate(self._shards):
            with shard["lock"]:
                self._purge_expired_in_shard(idx)
                for k, (v, _exp) in shard["store"].items():
                    collected.append((k, v))
        yield from collected
217
class SeverityCache(Cache[str, int]):
    """
    Specialization of Cache for community_id -> severity mappings.

    Provides convenience methods `set_severity` and `get_severity` that
    validate input types and keep the simple, explicit API used by the
    rest of the codebase.
    """

    def set_severity(self, community_id: str, severity: int) -> None:
        """
        Store severity for a community_id. Always keep the highest severity
        value if the key already exists.

        The compare-and-update runs entirely under the shard lock so that
        concurrent writers cannot lose a higher severity.

        Args:
            community_id: String identifier for the community.
            severity: Integer severity value.

        Raises:
            TypeError: If arguments are of incorrect types.
        """
        if not isinstance(community_id, str):
            raise TypeError("community_id must be a str")
        if not isinstance(severity, int):
            raise TypeError("severity must be an int")

        shard_idx = self._shard_for_key(community_id)
        shard = self._shards[shard_idx]
        lock: threading.RLock = shard["lock"]
        with lock:
            store: OrderedDict = shard["store"]
            # Purge expired entries in this shard first.
            self._purge_expired_in_shard(shard_idx)

            existing = store.get(community_id)
            if existing is not None:
                existing_val, existing_expires = existing
                # A live entry caps the write from below: keep the highest
                # severity seen. An expired entry counts as missing (the
                # purge above normally removes it already; this re-check
                # only matters if the clock crossed the deadline since).
                if not self._is_expired(existing_expires):
                    severity = max(existing_val, severity)
                # The key is definitely present here, so refresh recency.
                store.move_to_end(community_id)

            # Single insert path: stores the chosen value with a fresh TTL.
            store[community_id] = (severity, self._expires_at())

            # Evict LRU in this shard while over capacity.
            shard_max = self._shard_max_sizes[shard_idx]
            while len(store) > shard_max:
                store.popitem(last=False)
                self._record_eviction()

    def get_severity(self, community_id: str) -> Optional[int]:
        """
        Retrieve the cached severity for a community_id or None if missing.

        Args:
            community_id: String identifier for the community.

        Returns:
            The severity integer if present, otherwise None.

        Raises:
            TypeError: If `community_id` is not a str.
        """
        if not isinstance(community_id, str):
            raise TypeError("community_id must be a str")

        return self.get(community_id)
def reset_singletons() -> None:
    """
    Drop every cached instance from SingletonMeta's registry.

    Primarily a unit-testing hook: after calling this, each singleton
    class builds a fresh instance (with new constructor parameters) on
    its next instantiation. Avoid in production code — clearing the
    registry while other threads still hold references is unsafe.
    """
    registry = SingletonMeta
    with registry._lock:
        registry._instances.clear()