from typing import Optional, Generic, TypeVar, Iterator, Dict, Any, Tuple, List
from collections import OrderedDict
import threading
import time

K = TypeVar("K")
V = TypeVar("V")

class Cache(Generic[K, V], metaclass=SingletonMeta):
    """
    Sharded LRU cache with optional TTL, implemented as a singleton per class.

    Keys are partitioned across a fixed number of shards. Each shard owns
    its own OrderedDict and RLock, so operations on different shards can
    proceed concurrently without contending on a single global lock. Each
    shard enforces its own capacity (approximately max_size / num_shards),
    so the global max_size is a soft bound that holds closely when keys
    hash evenly.

    LRU ordering is maintained within each shard; entries may also carry a
    TTL after which they are treated as missing and purged lazily.
    """

    def __init__(self, max_size: int = 1024, ttl_seconds: Optional[float] = None, num_shards: int = 1) -> None:
        """
        Initialize the sharded cache.

        Args:
            max_size: Global maximum number of entries (positive int).
            ttl_seconds: Optional TTL in seconds for entries. If None, no time expiry.
            num_shards: Number of independent shards to partition the keyspace.

        Raises:
            ValueError: If any argument is out of range or of the wrong type.
        """
        # The singleton metaclass can hand back an already-built instance;
        # skip re-initialization so later constructor args are ignored.
        if hasattr(self, "_initialized") and self._initialized:
            return

        if not isinstance(max_size, int) or max_size <= 0:
            raise ValueError("max_size must be a positive integer")
        if ttl_seconds is not None and (not isinstance(ttl_seconds, (int, float)) or ttl_seconds <= 0):
            raise ValueError("ttl_seconds must be a positive number or None")
        if not isinstance(num_shards, int) or num_shards <= 0:
            raise ValueError("num_shards must be a positive integer")

        self.max_size: int = max_size
        self._ttl_seconds: Optional[float] = float(ttl_seconds) if ttl_seconds is not None else None

        # Never create more shards than max_size: a zero-capacity shard
        # would immediately evict everything stored in it.
        self._num_shards: int = min(num_shards, max_size)

        # Split max_size across shards; the first `remainder` shards get one
        # extra slot so the per-shard capacities sum exactly to max_size.
        base, remainder = divmod(max_size, self._num_shards)
        self._shard_max_sizes: List[int] = [base + (1 if i < remainder else 0) for i in range(self._num_shards)]

        # Per-shard stores and locks; each shard: {'lock': RLock, 'store': OrderedDict}
        self._shards: List[Dict[str, Any]] = [
            {"lock": threading.RLock(), "store": OrderedDict()} for _ in range(self._num_shards)
        ]

        # Stats counters for observability, guarded by their own lock.
        self._stats_lock = threading.Lock()
        self._hits: int = 0
        self._misses: int = 0
        self._evictions: int = 0

        self._initialized = True

    def _now(self) -> float:
        """Return current monotonic time (immune to wall-clock adjustments)."""
        return time.monotonic()

    def _expires_at(self) -> Optional[float]:
        """Return the expiry timestamp for a new entry, or None if no TTL."""
        if self._ttl_seconds is None:
            return None
        return self._now() + self._ttl_seconds

    def _is_expired(self, expires_at: Optional[float]) -> bool:
        """Return True if the given expires_at timestamp is in the past."""
        return expires_at is not None and expires_at <= self._now()

    def _record_hit(self) -> None:
        """Increment the hit counter under the stats lock."""
        with self._stats_lock:
            self._hits += 1

    def _record_miss(self) -> None:
        """Increment the miss counter under the stats lock."""
        with self._stats_lock:
            self._misses += 1

    def _record_eviction(self) -> None:
        """Increment the eviction counter under the stats lock."""
        with self._stats_lock:
            self._evictions += 1

    def get_stats(self) -> Dict[str, int]:
        """Return a consistent stats snapshot: hits, misses, evictions."""
        with self._stats_lock:
            return {"hits": self._hits, "misses": self._misses, "evictions": self._evictions}

    def reset_stats(self) -> None:
        """Zero all statistics."""
        with self._stats_lock:
            self._hits = 0
            self._misses = 0
            self._evictions = 0

    def _shard_for_key(self, key: K) -> int:
        """Return the shard index for a key (the mask keeps the hash non-negative)."""
        return (hash(key) & 0x7FFFFFFF) % self._num_shards

    def _purge_expired_in_shard(self, shard_idx: int) -> None:
        """Purge expired entries from a single shard. Caller must hold shard lock."""
        if self._ttl_seconds is None:
            return
        now = self._now()
        store: OrderedDict = self._shards[shard_idx]["store"]
        remove = [k for k, (_v, expires_at) in store.items() if expires_at is not None and expires_at <= now]
        for k in remove:
            store.pop(k, None)

    def _evict_lru(self, shard_idx: int, store: OrderedDict) -> None:
        """Evict LRU entries until the shard is within capacity. Caller holds shard lock."""
        shard_max = self._shard_max_sizes[shard_idx]
        while len(store) > shard_max:
            store.popitem(last=False)  # front of the OrderedDict is least recently used
            self._record_eviction()

    def set(self, key: K, value: V) -> None:
        """
        Insert or update a key in its shard and enforce shard-level capacity.
        """
        shard_idx = self._shard_for_key(key)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            store: OrderedDict = shard["store"]
            # Purge expired entries from this shard first
            self._purge_expired_in_shard(shard_idx)

            if key in store:
                # Assignment alone does not reorder an existing OrderedDict
                # key; move it explicitly so the update counts as MRU.
                store.move_to_end(key)
            store[key] = (value, self._expires_at())

            self._evict_lru(shard_idx, store)

    def get(self, key: K) -> Optional[V]:
        """
        Retrieve a key from its shard and mark it as recently used.

        Returns None if missing or expired.
        """
        shard_idx = self._shard_for_key(key)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            store: OrderedDict = shard["store"]
            self._purge_expired_in_shard(shard_idx)

            try:
                val, expires_at = store[key]
            except KeyError:
                self._record_miss()
                return None

            # Time may advance between the purge above and this check, so an
            # entry can still expire in that window; re-check and drop it.
            if self._is_expired(expires_at):
                store.pop(key, None)
                self._record_miss()
                return None

            # move_to_end is the idiomatic O(1) way to mark MRU
            # (replaces the previous pop-and-reinsert).
            store.move_to_end(key)
            self._record_hit()
            return val

    def __len__(self) -> int:
        """Return the total number of non-expired items across all shards."""
        total = 0
        # Shard locks are taken one at a time, never nested, so no deadlock.
        for i in range(self._num_shards):
            shard = self._shards[i]
            with shard["lock"]:
                self._purge_expired_in_shard(i)
                total += len(shard["store"])
        return total

    def clear(self) -> None:
        """Clear all shards."""
        for shard in self._shards:
            with shard["lock"]:
                shard["store"].clear()

    def __contains__(self, key: K) -> bool:
        """True if key exists and is not expired."""
        shard_idx = self._shard_for_key(key)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            self._purge_expired_in_shard(shard_idx)
            return key in shard["store"]

    def items(self) -> Iterator[Tuple[K, V]]:
        """Yield (key, value) pairs across shards (LRU within each shard)."""
        # Snapshot items under each shard lock so iteration never races mutations.
        snapshot: List[Tuple[K, V]] = []
        for i in range(self._num_shards):
            shard = self._shards[i]
            with shard["lock"]:
                self._purge_expired_in_shard(i)
                snapshot.extend((k, v) for k, (v, _e) in shard["store"].items())
        yield from snapshot

class SeverityCache(Cache[str, int]):
    """
    Specialization of Cache for community_id -> severity mappings.

    Provides convenience methods `set_severity` and `get_severity` that
    validate input types and keep the simple, explicit API used by the
    rest of the codebase.
    """

    def set_severity(self, community_id: str, severity: int) -> None:
        """
        Store severity for a community_id. Always keep the highest severity
        value if the key already exists.

        The compare-and-update runs under the shard lock so concurrent
        writers cannot interleave between the read and the write.

        Args:
            community_id: String identifier for the community.
            severity: Integer severity value.

        Raises:
            TypeError: If arguments are of incorrect types.
        """
        if not isinstance(community_id, str):
            raise TypeError("community_id must be a str")
        if not isinstance(severity, int):
            raise TypeError("severity must be an int")

        shard_idx = self._shard_for_key(community_id)
        shard = self._shards[shard_idx]
        with shard["lock"]:
            store: OrderedDict = shard["store"]
            # Purge expired entries in this shard first
            self._purge_expired_in_shard(shard_idx)

            existing = store.get(community_id)
            if existing is not None:
                existing_val, existing_expires = existing
                # An entry can expire between the purge above and this read;
                # treat it as missing then. _is_expired already handles None,
                # so the previous extra `is not None` guard was redundant.
                chosen = severity if self._is_expired(existing_expires) else max(existing_val, severity)
                # `existing is not None` proves the key is present, so no
                # second membership check is needed. Assignment alone does
                # not reorder an existing key; promote it to MRU explicitly.
                store.move_to_end(community_id)
                store[community_id] = (chosen, self._expires_at())
            else:
                store[community_id] = (severity, self._expires_at())

            # Evict LRU in this shard while over capacity
            shard_max = self._shard_max_sizes[shard_idx]
            while len(store) > shard_max:
                store.popitem(last=False)
                self._record_eviction()

    def get_severity(self, community_id: str) -> Optional[int]:
        """
        Retrieve the cached severity for a community_id or None if missing.

        Args:
            community_id: String identifier for the community.

        Returns:
            The severity integer if present, otherwise None.

        Raises:
            TypeError: If `community_id` is not a str.
        """
        if not isinstance(community_id, str):
            raise TypeError("community_id must be a str")

        return self.get(community_id)

def reset_singletons() -> None:
    """
    Clear the internal singleton instance registry.

    Intended primarily for unit tests, so successive tests can construct
    singletons with different parameters and receive fresh instances.
    Use with care in production code: dropping singletons while other
    threads still hold references can be unsafe.
    """
    registry_lock = SingletonMeta._lock
    with registry_lock:
        SingletonMeta._instances.clear()