diff --git a/.gitignore b/.gitignore
index fe1c538..3e9c748 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .env
 __pycache__
 *.db
+cache
diff --git a/cache.py b/cache.py
index e840864..cd46a40 100644
--- a/cache.py
+++ b/cache.py
@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from typing import Dict, Any, Optional
 import diskcache
@@ -19,11 +20,34 @@ class Cache:
 
     def get(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         key = self._make_key(method, params)
-        return self.cache.get(key)
+        cached_data = self.cache.get(key)
+
+        if cached_data is None:
+            return None
+
+        # Check if cached data has TTL and if it's expired
+        if isinstance(cached_data, dict) and '_cache_expiry' in cached_data:
+            if time.time() > cached_data['_cache_expiry']:
+                # Remove expired entry
+                self.cache.delete(key)
+                return None
+            # Remove cache metadata before returning
+            response = cached_data.copy()
+            del response['_cache_expiry']
+            return response
+
+        return cached_data
 
-    def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any]) -> None:
+    def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any], ttl: Optional[int] = None) -> None:
         key = self._make_key(method, params)
-        self.cache.set(key, response)
+
+        # Add TTL metadata if specified
+        if ttl is not None:
+            cached_response = response.copy()
+            cached_response['_cache_expiry'] = time.time() + ttl
+            self.cache.set(key, cached_response)
+        else:
+            self.cache.set(key, response)
 
     def size_check(self) -> Dict[str, Any]:
         stats = self.cache.stats()
diff --git a/cache_policy.py b/cache_policy.py
new file mode 100644
index 0000000..4b27d92
--- /dev/null
+++ b/cache_policy.py
@@ -0,0 +1,95 @@
+from typing import Dict, Any, Optional
+import time
+
+
+class CachePolicy:
+    """
+    Determines caching behavior for Solana RPC methods based on their characteristics.
+    """
+
+    # Methods that return immutable data - cache indefinitely
+    CACHEABLE_IMMUTABLE = {
+        'getGenesisHash'  # Network genesis hash never changes
+    }
+
+    # Methods with time-based TTL caching based on data change frequency
+    CACHEABLE_WITH_TTL = {
+        # Network/validator information - changes periodically
+        'getVoteAccounts': 120,  # Validator vote accounts change every few minutes
+        'getSupply': 300,  # Total SOL supply changes slowly
+
+        # Epoch and network info - changes with epoch boundaries (~2-3 days)
+        'getEpochInfo': 3600,  # Current epoch info changes slowly
+        'getInflationRate': 1800,  # Inflation rate changes infrequently
+        'getInflationGovernor': 3600,  # Inflation governor params rarely change
+
+        # Network constants - change very rarely or never
+        'getEpochSchedule': 86400,  # Epoch schedule rarely changes
+        'getVersion': 3600,  # RPC version changes occasionally
+        'getIdentity': 3600,  # Node identity changes rarely
+
+        # Never change for the given parameters but will add new entry in the DB if the input parameters change
+        'getBlock': 86400,
+        'getTransaction': 86400
+    }
+
+    def should_cache(self, method: str, params: Dict[str, Any]) -> bool:
+        """
+        Determine if a method should be cached based on the method name and parameters.
+
+        Args:
+            method: The RPC method name
+            params: The method parameters
+
+        Returns:
+            True if the method should be cached, False otherwise
+        """
+        if method in self.CACHEABLE_IMMUTABLE:
+            return True
+
+        if method in self.CACHEABLE_WITH_TTL:
+            # For getBlock, only cache finalized blocks
+            if method == 'getBlock':
+                commitment = self._get_commitment(params)
+                return commitment == 'finalized'
+            return True
+
+        # Default to not caching unknown methods
+        return False
+
+    def get_cache_ttl(self, method: str, params: Dict[str, Any]) -> Optional[int]:
+        """
+        Get the Time To Live (TTL) for a cached method in seconds.
+
+        Args:
+            method: The RPC method name
+            params: The method parameters
+
+        Returns:
+            TTL in seconds, or None for indefinite caching
+        """
+        if method in self.CACHEABLE_IMMUTABLE:
+            return None  # Cache indefinitely
+
+        if method in self.CACHEABLE_WITH_TTL:
+            return self.CACHEABLE_WITH_TTL[method]
+
+        return None
+
+    def _get_commitment(self, params: Dict[str, Any]) -> str:
+        """
+        Extract the commitment level from RPC parameters.
+
+        Args:
+            params: The method parameters
+
+        Returns:
+            The commitment level, defaults to 'processed'
+        """
+        if isinstance(params, list) and len(params) > 1:
+            if isinstance(params[1], dict) and 'commitment' in params[1]:
+                return params[1]['commitment']
+        elif isinstance(params, dict) and 'commitment' in params:
+            return params['commitment']
+
+        return 'processed'  # Default commitment level
diff --git a/main.py b/main.py
index cf8785a..8499c99 100644
--- a/main.py
+++ b/main.py
@@ -36,7 +36,6 @@ def load_config() -> dict:
     return {
         "proxy_port": int(os.getenv("PROXY_PORT", 8545)),
         "cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 100)),
-        "disable_cache": os.getenv("DISABLE_CACHE", "true").lower() == "true",
        "backoff_minutes": int(os.getenv("BACKOFF_MINUTES", 30)),
         "log_level": os.getenv("LOG_LEVEL", "INFO"),
         "error_db_path": os.getenv("ERROR_DB_PATH", "./errors.db"),
@@ -56,7 +55,7 @@ def create_app(config: dict) -> web.Application:
     providers = create_providers()
     cache = Cache(size_limit_gb=config["cache_size_gb"])
     error_logger = ErrorLogger(db_path=config["error_db_path"])
-    router = Router(providers, cache, error_logger, config["disable_cache"])
+    router = Router(providers, cache, error_logger)
 
     app['router'] = router
     app['config'] = config
@@ -73,10 +72,7 @@ def main() -> None:
     logger = logging.getLogger(__name__)
 
     logger.info(f"Starting Solana RPC Proxy on port {config['proxy_port']}")
-    if config['disable_cache']:
-        logger.info("Cache is DISABLED - all responses will be fresh")
-    else:
-        logger.info(f"Cache size limit: {config['cache_size_gb']}GB")
+    logger.info(f"Intelligent caching enabled - Cache size limit: {config['cache_size_gb']}GB")
     logger.info(f"Provider backoff time: {config['backoff_minutes']} minutes")
 
     app = create_app(config)
diff --git a/router.py b/router.py
index 5410bc1..1c3f309 100644
--- a/router.py
+++ b/router.py
@@ -4,22 +4,26 @@ import logging
 from typing import Dict, Any, Optional, List
 from providers import Provider
 from cache import Cache
+from cache_policy import CachePolicy
 from errors import ErrorLogger
 
 
 class Router:
-    def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger, disable_cache: bool = False):
+    def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger):
         self.providers = providers
         self.cache = cache
+        self.cache_policy = CachePolicy()
         self.error_logger = error_logger
-        self.disable_cache = disable_cache
         self.current_provider_index = 0
         self.logger = logging.getLogger(__name__)
 
     async def route_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
         request = {"method": method, "params": params}
 
-        if not self.disable_cache:
+        # Check if this method should be cached based on intelligent caching policy
+        should_cache = self.cache_policy.should_cache(method, params)
+
+        if should_cache:
             cached_response = self.cache.get(method, params)
             if cached_response:
                 self.logger.debug(f"Cache hit for {method}")
@@ -42,10 +46,15 @@ class Router:
                 transformed_response["_cached"] = False
                 transformed_response["_provider"] = provider.name
 
-                if not self.disable_cache:
-                    self.cache.set(method, params, transformed_response)
+                # Cache the response if caching policy allows it
+                if should_cache:
+                    ttl = self.cache_policy.get_cache_ttl(method, params)
+                    self.cache.set(method, params, transformed_response, ttl)
+                    cache_info = f" (cached {'indefinitely' if ttl is None else f'for {ttl}s'})"
+                    self.logger.info(f"Request succeeded via {provider.name}{cache_info}")
+                else:
+                    self.logger.info(f"Request succeeded via {provider.name} (not cached)")
 
-                self.logger.info(f"Request succeeded via {provider.name}")
                 return transformed_response
 
             except Exception as error:
@@ -61,11 +70,10 @@ class Router:
 
     def get_next_available_provider(self) -> Optional[Provider]:
         for _ in range(len(self.providers)):
             provider = self.providers[self.current_provider_index]
+            self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
             if provider.is_available():
                 return provider
-            else:
-                self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
 
         return None
 
diff --git a/ws_proxy.py b/ws_proxy.py
index dd73c66..a2f2e1e 100644
--- a/ws_proxy.py
+++ b/ws_proxy.py
@@ -110,11 +110,6 @@ class WebSocketProxy:
                 transformed_response["_cached"] = False
                 transformed_response["_provider"] = provider.name
 
-                method = transformed_response.get("method", "")
-                params = transformed_response.get("params", {})
-                if method and params and not self.router.disable_cache:
-                    self.router.cache.set(method, params, transformed_response)
-
                 await client_ws.send_str(json.dumps(transformed_response))
                 self.logger.debug(f"Forwarded response to client from {provider.name}: {transformed_response}")
 