Implement selective caching for RPC method responses

Shreerang Kale 2025-08-01 14:23:23 +05:30
parent 3da690f2d1
commit 23d69eec14
6 changed files with 141 additions and 22 deletions

.gitignore

@@ -1,3 +1,4 @@
 .env
 __pycache__
 *.db
+cache

cache.py

@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from typing import Dict, Any, Optional
 import diskcache
@@ -19,10 +20,33 @@ class Cache:
     def get(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         key = self._make_key(method, params)
-        return self.cache.get(key)
+        cached_data = self.cache.get(key)
+        if cached_data is None:
+            return None
+
+        # Check if cached data has TTL and if it's expired
+        if isinstance(cached_data, dict) and '_cache_expiry' in cached_data:
+            if time.time() > cached_data['_cache_expiry']:
+                # Remove expired entry
+                self.cache.delete(key)
+                return None
+            # Remove cache metadata before returning
+            response = cached_data.copy()
+            del response['_cache_expiry']
+            return response
+
+        return cached_data

-    def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any]) -> None:
+    def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any], ttl: Optional[int] = None) -> None:
         key = self._make_key(method, params)
-        self.cache.set(key, response)
+
+        # Add TTL metadata if specified
+        if ttl is not None:
+            cached_response = response.copy()
+            cached_response['_cache_expiry'] = time.time() + ttl
+            self.cache.set(key, cached_response)
+        else:
+            self.cache.set(key, response)

     def size_check(self) -> Dict[str, Any]:
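
A small, hypothetical round-trip through the TTL-aware cache for clarity. It assumes the existing constructor accepts size_limit_gb (as used in create_app further down) and that _make_key accepts the same method/params pair passed to get and set; the method names and values below are illustrative, not taken from the commit.

import time

from cache import Cache

# size_limit_gb mirrors the keyword used in create_app() below (assumption:
# the existing __init__ accepts it; this commit does not change the constructor).
cache = Cache(size_limit_gb=1)

params = {"commitment": "finalized"}
response = {"jsonrpc": "2.0", "result": {"epoch": 700}, "id": 1}

# Stored with a 5-second TTL: '_cache_expiry' is added on write and stripped on read.
cache.set("getEpochInfo", params, response, ttl=5)
print(cache.get("getEpochInfo", params))   # the original response dict

time.sleep(6)
print(cache.get("getEpochInfo", params))   # None - the expired entry was deleted

# ttl=None (the default) caches indefinitely, as used for immutable methods.
cache.set("getGenesisHash", {}, response)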

cache_policy.py (new file)

@@ -0,0 +1,95 @@
from typing import Dict, Any, Optional
import time


class CachePolicy:
    """
    Determines caching behavior for Solana RPC methods based on their characteristics.
    """

    # Methods that return immutable data - cache indefinitely
    CACHEABLE_IMMUTABLE = {
        'getGenesisHash',  # Network genesis hash never changes
    }

    # Methods with time-based TTL caching based on data change frequency
    CACHEABLE_WITH_TTL = {
        # Network/validator information - changes periodically
        'getVoteAccounts': 120,        # Validator vote accounts change every few minutes
        'getSupply': 300,              # Total SOL supply changes slowly

        # Epoch and network info - changes with epoch boundaries (~2-3 days)
        'getEpochInfo': 3600,          # Current epoch info changes slowly
        'getInflationRate': 1800,      # Inflation rate changes infrequently
        'getInflationGovernor': 3600,  # Inflation governor params rarely change

        # Network constants - change very rarely or never
        'getEpochSchedule': 86400,     # Epoch schedule rarely changes
        'getVersion': 3600,            # RPC version changes occasionally
        'getIdentity': 3600,           # Node identity changes rarely

        # Results never change for a given set of parameters; a different set of
        # parameters simply produces a new cache entry
        'getBlock': 86400,
        'getTransaction': 86400,
    }

    def should_cache(self, method: str, params: Dict[str, Any]) -> bool:
        """
        Determine if a method should be cached based on the method name and parameters.

        Args:
            method: The RPC method name
            params: The method parameters

        Returns:
            True if the method should be cached, False otherwise
        """
        if method in self.CACHEABLE_IMMUTABLE:
            return True

        if method in self.CACHEABLE_WITH_TTL:
            # For getBlock, only cache finalized blocks
            if method == 'getBlock':
                commitment = self._get_commitment(params)
                return commitment == 'finalized'
            return True

        # Default to not caching unknown methods
        return False

    def get_cache_ttl(self, method: str, params: Dict[str, Any]) -> Optional[int]:
        """
        Get the Time To Live (TTL) for a cached method in seconds.

        Args:
            method: The RPC method name
            params: The method parameters

        Returns:
            TTL in seconds, or None for indefinite caching
        """
        if method in self.CACHEABLE_IMMUTABLE:
            return None  # Cache indefinitely

        if method in self.CACHEABLE_WITH_TTL:
            return self.CACHEABLE_WITH_TTL[method]

        return None

    def _get_commitment(self, params: Dict[str, Any]) -> str:
        """
        Extract the commitment level from RPC parameters.

        Args:
            params: The method parameters

        Returns:
            The commitment level, defaults to 'processed'
        """
        if isinstance(params, list) and len(params) > 1:
            if isinstance(params[1], dict) and 'commitment' in params[1]:
                return params[1]['commitment']
        elif isinstance(params, dict) and 'commitment' in params:
            return params['commitment']

        return 'processed'  # Default commitment level
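
To make the policy concrete, a short hypothetical usage sketch (not part of the commit); the method names and TTLs are the ones defined above, and the getBlock calls exercise the finalized-only rule; the slot number and params are illustrative.

from cache_policy import CachePolicy

policy = CachePolicy()

# Immutable data: cached with no expiry
print(policy.should_cache('getGenesisHash', []))    # True
print(policy.get_cache_ttl('getGenesisHash', []))   # None -> cache indefinitely

# TTL-based caching for slowly changing data
print(policy.get_cache_ttl('getEpochInfo', []))     # 3600 (seconds)
print(policy.get_cache_ttl('getVoteAccounts', []))  # 120

# getBlock is cached only when the request asks for finalized data
print(policy.should_cache('getBlock', [250000000, {'commitment': 'finalized'}]))  # True
print(policy.should_cache('getBlock', [250000000, {'commitment': 'processed'}]))  # False

# Unknown or fast-changing methods fall through to "do not cache"
print(policy.should_cache('getLatestBlockhash', []))  # False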

View File

@@ -36,7 +36,6 @@ def load_config() -> dict:
     return {
         "proxy_port": int(os.getenv("PROXY_PORT", 8545)),
         "cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 100)),
-        "disable_cache": os.getenv("DISABLE_CACHE", "true").lower() == "true",
         "backoff_minutes": int(os.getenv("BACKOFF_MINUTES", 30)),
         "log_level": os.getenv("LOG_LEVEL", "INFO"),
         "error_db_path": os.getenv("ERROR_DB_PATH", "./errors.db"),
@@ -56,7 +55,7 @@ def create_app(config: dict) -> web.Application:
     providers = create_providers()
     cache = Cache(size_limit_gb=config["cache_size_gb"])
     error_logger = ErrorLogger(db_path=config["error_db_path"])
-    router = Router(providers, cache, error_logger, config["disable_cache"])
+    router = Router(providers, cache, error_logger)

     app['router'] = router
     app['config'] = config
@@ -73,10 +72,7 @@ def main() -> None:
     logger = logging.getLogger(__name__)
     logger.info(f"Starting Solana RPC Proxy on port {config['proxy_port']}")

-    if config['disable_cache']:
-        logger.info("Cache is DISABLED - all responses will be fresh")
-    else:
-        logger.info(f"Cache size limit: {config['cache_size_gb']}GB")
+    logger.info(f"Intelligent caching enabled - Cache size limit: {config['cache_size_gb']}GB")
     logger.info(f"Provider backoff time: {config['backoff_minutes']} minutes")

     app = create_app(config)
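
Since DISABLE_CACHE is no longer read anywhere, the proxy is configured entirely through the remaining environment variables. A hypothetical sketch follows; the variable names come from load_config() above, and the values are simply the same defaults the code falls back to.

import os

# Illustrative only - none of these assignments appear in the commit.
os.environ["PROXY_PORT"] = "8545"
os.environ["CACHE_SIZE_GB"] = "100"
os.environ["BACKOFF_MINUTES"] = "30"
os.environ["LOG_LEVEL"] = "INFO"
os.environ["ERROR_DB_PATH"] = "./errors.db"

# DISABLE_CACHE is intentionally absent: caching decisions are now made
# per-method by CachePolicy instead of a global on/off switch.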

View File

@@ -4,22 +4,26 @@ import logging
 from typing import Dict, Any, Optional, List

 from providers import Provider
 from cache import Cache
+from cache_policy import CachePolicy
 from errors import ErrorLogger


 class Router:
-    def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger, disable_cache: bool = False):
+    def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger):
         self.providers = providers
         self.cache = cache
+        self.cache_policy = CachePolicy()
         self.error_logger = error_logger
-        self.disable_cache = disable_cache
         self.current_provider_index = 0
         self.logger = logging.getLogger(__name__)

     async def route_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
         request = {"method": method, "params": params}

-        if not self.disable_cache:
+        # Check if this method should be cached based on intelligent caching policy
+        should_cache = self.cache_policy.should_cache(method, params)
+        if should_cache:
             cached_response = self.cache.get(method, params)
             if cached_response:
                 self.logger.debug(f"Cache hit for {method}")
@@ -42,10 +46,15 @@ class Router:
                 transformed_response["_cached"] = False
                 transformed_response["_provider"] = provider.name

-                if not self.disable_cache:
-                    self.cache.set(method, params, transformed_response)
-
-                self.logger.info(f"Request succeeded via {provider.name}")
+                # Cache the response if caching policy allows it
+                if should_cache:
+                    ttl = self.cache_policy.get_cache_ttl(method, params)
+                    self.cache.set(method, params, transformed_response, ttl)
+                    cache_info = f" (cached {'indefinitely' if ttl is None else f'for {ttl}s'})"
+                    self.logger.info(f"Request succeeded via {provider.name}{cache_info}")
+                else:
+                    self.logger.info(f"Request succeeded via {provider.name} (not cached)")

                 return transformed_response

             except Exception as error:
@@ -61,11 +70,10 @@ class Router:
     def get_next_available_provider(self) -> Optional[Provider]:
         for _ in range(len(self.providers)):
             provider = self.providers[self.current_provider_index]
+            self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)

             if provider.is_available():
                 return provider
-            else:
-                self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)

         return None
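
The get_next_available_provider change also affects rotation: the index now advances on every probe rather than only when a provider is unavailable, so consecutive calls rotate across healthy providers instead of sticking to the first one. A minimal standalone sketch of the new loop, using hypothetical stand-in providers (SimpleNamespace stubs, not the project's Provider class):

from types import SimpleNamespace


class RoundRobinDemo:
    """Stripped-down copy of the provider-selection loop shown in the diff."""

    def __init__(self, providers):
        self.providers = providers
        self.current_provider_index = 0

    def get_next_available_provider(self):
        for _ in range(len(self.providers)):
            provider = self.providers[self.current_provider_index]
            self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
            if provider.is_available():
                return provider
        return None


providers = [
    SimpleNamespace(name="provider-a", is_available=lambda: True),
    SimpleNamespace(name="provider-b", is_available=lambda: False),
    SimpleNamespace(name="provider-c", is_available=lambda: True),
]

selector = RoundRobinDemo(providers)
print([selector.get_next_available_provider().name for _ in range(4)])
# ['provider-a', 'provider-c', 'provider-a', 'provider-c'] - the unavailable provider-b is skipped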

View File

@@ -110,11 +110,6 @@ class WebSocketProxy:
                 transformed_response["_cached"] = False
                 transformed_response["_provider"] = provider.name

-                method = transformed_response.get("method", "")
-                params = transformed_response.get("params", {})
-                if method and params and not self.router.disable_cache:
-                    self.router.cache.set(method, params, transformed_response)
-
                 await client_ws.send_str(json.dumps(transformed_response))
                 self.logger.debug(f"Forwarded response to client from {provider.name}: {transformed_response}")