Handle CORS and selectively cache responses for appropriate methods #1

Merged
prathamesh merged 12 commits from sk-cors-fix into main 2025-08-01 10:37:07 +00:00
12 changed files with 352 additions and 137 deletions

View File

@ -1,4 +1,4 @@
# Provider endpoints and auth
# Provider endpoints and auth (optional)
ALCHEMY_API_KEY=your_key_here
HELIUS_API_KEY=your_key_here
QUICKNODE_ENDPOINT=your_endpoint.quiknode.pro
@ -6,7 +6,7 @@ QUICKNODE_TOKEN=your_token_here
# Proxy settings
PROXY_PORT=8545
CACHE_SIZE_GB=100
CACHE_SIZE_GB=1
BACKOFF_MINUTES=30
# Logging

3
.gitignore vendored
View File

@ -1 +1,4 @@
.env
__pycache__
*.db
cache

View File

@ -1 +1,23 @@
The trenches are brutal.
# Solana Proxy
## Setup
- Copy `.env.example` to `.env`:
```bash
cp .env.example .env
```
- The proxy works without any changes to the `.env` file, but you can optionally set API keys for the different providers.
## Run
- Start the proxy:
```bash
python3 main.py
```
- This starts the proxy with:
- RPC endpoint at: <http://0.0.0.0:8545>
- WS endpoint at: <ws://0.0.0.0:8545/ws>

View File

@ -1,11 +1,12 @@
import json
import os
import time
from typing import Dict, Any, Optional
import diskcache
class Cache:
def __init__(self, cache_dir: str = "./cache", size_limit_gb: int = 100):
def __init__(self, cache_dir: str = "./cache", size_limit_gb: int = 1):
self.cache_dir = cache_dir
self.size_limit_bytes = size_limit_gb * 1024 * 1024 * 1024
self.cache = diskcache.Cache(
@ -19,11 +20,34 @@ class Cache:
def get(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
key = self._make_key(method, params)
return self.cache.get(key)
cached_data = self.cache.get(key)
def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any]) -> None:
if cached_data is None:
return None
# Check if cached data has TTL and if it's expired
if isinstance(cached_data, dict) and '_cache_expiry' in cached_data:
if time.time() > cached_data['_cache_expiry']:
# Remove expired entry
self.cache.delete(key)
return None
# Remove cache metadata before returning
response = cached_data.copy()
del response['_cache_expiry']
return response
return cached_data
def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any], ttl: Optional[int] = None) -> None:
key = self._make_key(method, params)
self.cache.set(key, response)
# Add TTL metadata if specified
if ttl is not None:
cached_response = response.copy()
cached_response['_cache_expiry'] = time.time() + ttl
self.cache.set(key, cached_response)
else:
self.cache.set(key, response)
def size_check(self) -> Dict[str, Any]:
stats = self.cache.stats()
@ -33,4 +57,3 @@ class Cache:
"count": stats[0],
"limit_gb": self.size_limit_bytes / (1024 * 1024 * 1024)
}

96
cache_policy.py Normal file
View File

@ -0,0 +1,96 @@
from typing import Dict, Any, Optional
import time
class CachePolicy:
    """
    Determines caching behavior for Solana RPC methods based on their characteristics.

    Two categories are recognized:
      * CACHEABLE_IMMUTABLE  — responses that never change; cached indefinitely.
      * CACHEABLE_WITH_TTL   — responses that drift over time; cached with a
        per-method TTL (seconds).
    Anything else is never cached.
    """

    # Methods that return immutable data - cache indefinitely
    CACHEABLE_IMMUTABLE = {
        'getGenesisHash'  # Network genesis hash never changes
    }

    # Methods with time-based TTL caching based on data change frequency
    CACHEABLE_WITH_TTL = {
        # Network/validator information - changes periodically
        'getVoteAccounts': 120,       # Validator vote accounts change every few minutes
        'getSupply': 300,             # Total SOL supply changes slowly
        # Epoch and network info - changes with epoch boundaries (~2-3 days)
        'getEpochInfo': 3600,         # Current epoch info changes slowly
        'getInflationRate': 1800,     # Inflation rate changes infrequently
        'getInflationGovernor': 3600, # Inflation governor params rarely change
        # Network constants - change very rarely or never
        'getEpochSchedule': 86400,    # Epoch schedule rarely changes
        'getVersion': 3600,           # RPC version changes occasionally
        'getIdentity': 3600,          # Node identity changes rarely
        # Never change for the given parameters but will add new entry in the DB
        # if the input parameters change
        'getBlock': 86400,
        'getTransaction': 86400,
        # NOTE(review): getTransaction results at non-finalized commitments can
        # still change; consider gating it on commitment like getBlock — confirm.
    }

    # Commitment levels at which a block is immutable and therefore safe to cache.
    _FINAL_COMMITMENTS = {'finalized'}

    def should_cache(self, method: str, params: Dict[str, Any]) -> bool:
        """
        Determine if a method should be cached based on the method name and parameters.

        Args:
            method: The RPC method name
            params: The method parameters
        Returns:
            True if the method should be cached, False otherwise
        """
        if method in self.CACHEABLE_WITH_TTL:
            # For getBlock, only cache finalized blocks. Solana's getBlock
            # defaults to 'finalized' when no commitment is supplied, so an
            # absent commitment is also safe to cache.
            if method == 'getBlock':
                commitment = self._get_commitment(params, default='finalized')
                return commitment in self._FINAL_COMMITMENTS
            return True
        if method in self.CACHEABLE_IMMUTABLE:
            return True
        # Default to not caching unknown methods
        return False

    def get_cache_ttl(self, method: str, params: Dict[str, Any]) -> Optional[int]:
        """
        Get the Time To Live (TTL) for a cached method in seconds.

        Args:
            method: The RPC method name
            params: The method parameters
        Returns:
            TTL in seconds, or None for indefinite caching
        """
        if method in self.CACHEABLE_IMMUTABLE:
            return None  # Cache indefinitely
        if method in self.CACHEABLE_WITH_TTL:
            return self.CACHEABLE_WITH_TTL[method]
        return None

    def _get_commitment(self, params: Dict[str, Any], default: str = 'processed') -> str:
        """
        Extract the commitment level from RPC parameters.

        Args:
            params: The method parameters (positional list or keyword dict)
            default: Value returned when no commitment is present
        Returns:
            The commitment level, or *default* if none was supplied
        """
        if isinstance(params, list) and len(params) > 1:
            if isinstance(params[1], dict) and 'commitment' in params[1]:
                return params[1]['commitment']
        elif isinstance(params, dict) and 'commitment' in params:
            return params['commitment']
        return default

View File

@ -34,7 +34,7 @@ Provider class:
Cache class:
- get(method: str, params: dict) -> Optional[response]
- set(method: str, params: dict, response: dict) -> None
- size_check() -> None # Enforce 100GB limit
- size_check() -> None # Enforce 1GB limit
- clear_oldest() -> None # LRU eviction
```
@ -42,7 +42,7 @@ Cache class:
- Use `diskcache` library for simplicity
- Key format: `f"{method}:{json.dumps(params, sort_keys=True)}"`
- Store both HTTP responses and WebSocket messages
- Implement 100GB limit with LRU eviction
- Implement 1GB limit with LRU eviction
### 3. Error Logger Module (`errors.py`)
**Purpose**: SQLite-based error logging with UUID tracking
@ -146,7 +146,7 @@ QUICKNODE_TOKEN=your_token_here
# Proxy settings
PROXY_PORT=8545
CACHE_SIZE_GB=100
CACHE_SIZE_GB=1
BACKOFF_MINUTES=30
# Logging
@ -227,7 +227,7 @@ Happy-path end-to-end tests only:
## Deployment Considerations
1. **Cache Storage**: Need ~100GB disk space
1. **Cache Storage**: Need ~1GB disk space
2. **Memory Usage**: Keep minimal, use disk cache
3. **Concurrent Clients**: Basic round-robin if multiple connect
4. **Monitoring**: Log all errors, provide error IDs
@ -273,7 +273,7 @@ aiohttp-cors==0.7.0
1. Single endpoint proxies to 5 providers
2. Automatic failover works
3. Responses are cached (up to 100GB)
3. Responses are cached (up to 1GB)
4. Errors logged with retrievable IDs
5. Both HTTP and WebSocket work
6. Response format is unified

View File

@ -38,6 +38,7 @@ async def handle_rpc_request(request: web.Request) -> web.Response:
logger.info(f"Handling RPC request: {method}")
response = await router.route_request(method, params)
response["id"] = request_id
return web.json_response(response)
@ -64,10 +65,5 @@ async def handle_rpc_request(request: web.Request) -> web.Response:
}, status=500)
def setup_routes(app: web.Application) -> None:
app.router.add_post('/', handle_rpc_request)

26
main.py
View File

@ -1,5 +1,6 @@
import os
import logging
import asyncio
from dotenv import load_dotenv
from aiohttp import web
from providers import create_providers
@ -10,12 +11,31 @@ from http_proxy import setup_routes
from ws_proxy import setup_ws_routes
@web.middleware
async def cors_middleware(request, handler):
    """Attach permissive CORS headers to every response.

    OPTIONS preflight requests are answered directly (with a one-day
    Max-Age so browsers cache the preflight). For all other methods the
    wrapped handler runs first and the CORS headers are added to its
    response. Headers are also attached to aiohttp HTTPException error
    responses (404/405/handler-raised errors), which previously bypassed
    this middleware's header-adding path and appeared to browsers as
    CORS failures instead of the real error.
    """
    cors_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
        'Access-Control-Allow-Headers': '*',
    }
    if request.method == 'OPTIONS':
        # Handle preflight requests without invoking the handler.
        return web.Response(
            headers={**cors_headers, 'Access-Control-Max-Age': '86400'}
        )
    try:
        response = await handler(request)
    except web.HTTPException as exc:
        # HTTPExceptions ARE responses in aiohttp; decorate and re-raise
        # so cross-origin callers can still read the error.
        exc.headers.update(cors_headers)
        raise
    response.headers.update(cors_headers)
    return response
def load_config() -> dict:
load_dotenv()
return {
"proxy_port": int(os.getenv("PROXY_PORT", 8545)),
"cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 100)),
"cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 1)),
"backoff_minutes": int(os.getenv("BACKOFF_MINUTES", 30)),
"log_level": os.getenv("LOG_LEVEL", "INFO"),
"error_db_path": os.getenv("ERROR_DB_PATH", "./errors.db"),
@ -30,7 +50,7 @@ def setup_logging(log_level: str) -> None:
def create_app(config: dict) -> web.Application:
app = web.Application()
app = web.Application(middlewares=[cors_middleware])
providers = create_providers()
cache = Cache(size_limit_gb=config["cache_size_gb"])
@ -46,8 +66,6 @@ def create_app(config: dict) -> web.Application:
return app
def main() -> None:
config = load_config()
setup_logging(config["log_level"])

View File

@ -105,9 +105,9 @@ class SolanaPublicProvider(Provider):
def create_providers() -> list[Provider]:
return [
SolanaPublicProvider(),
AlchemyProvider(),
PublicNodeProvider(),
HeliusProvider(),
QuickNodeProvider(),
SolanaPublicProvider()
]

View File

@ -4,6 +4,7 @@ import logging
from typing import Dict, Any, Optional, List
from providers import Provider
from cache import Cache
from cache_policy import CachePolicy
from errors import ErrorLogger
@ -11,6 +12,7 @@ class Router:
def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger):
self.providers = providers
self.cache = cache
self.cache_policy = CachePolicy()
self.error_logger = error_logger
self.current_provider_index = 0
self.logger = logging.getLogger(__name__)
@ -18,12 +20,16 @@ class Router:
async def route_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
request = {"method": method, "params": params}
cached_response = self.cache.get(method, params)
if cached_response:
self.logger.debug(f"Cache hit for {method}")
cached_response["_cached"] = True
cached_response["_provider"] = "cache"
return cached_response
# Check if this method should be cached based on caching policy
should_cache = self.cache_policy.should_cache(method, params)
if should_cache:
cached_response = self.cache.get(method, params)
if cached_response:
self.logger.debug(f"Cache hit for {method}")
cached_response["_cached"] = True
cached_response["_provider"] = "cache"
return cached_response
for attempt in range(len(self.providers)):
provider = self.get_next_available_provider()
@ -40,14 +46,27 @@ class Router:
transformed_response["_cached"] = False
transformed_response["_provider"] = provider.name
self.cache.set(method, params, transformed_response)
self.logger.info(f"Request succeeded via {provider.name}")
# Cache the response if caching policy allows it
if should_cache:
ttl = self.cache_policy.get_cache_ttl(method, params)
self.cache.set(method, params, transformed_response, ttl)
cache_info = f" (cached {'indefinitely' if ttl is None else f'for {ttl}s'})"
self.logger.info(f"Request succeeded via {provider.name}{cache_info}")
else:
self.logger.info(f"Request succeeded via {provider.name} (not cached)")
return transformed_response
except Exception as error:
error_id = self.error_logger.log_error(provider.name, request, error)
self.logger.warning(f"Provider {provider.name} failed: {error} (ID: {error_id})")
provider.mark_failed()
# Only mark provider as failed for server/network issues, not RPC errors
if await self._is_server_failure(provider, error):
provider.mark_failed()
self.logger.warning(f"Provider {provider.name} marked as failed due to server issue")
else:
self.logger.debug(f"Provider {provider.name} had RPC error but server is available")
return self._create_error_response(
"All providers failed to handle the request",
@ -64,6 +83,29 @@ class Router:
return None
async def _is_server_failure(self, provider: Provider, error: Exception) -> bool:
    """Decide whether a provider error reflects an unreachable server.

    Probes the provider's host with a cheap HTTP GET. Any HTTP answer at
    all (even a 4xx/5xx status) means the server is alive, so the failure
    was an RPC-level problem and the provider should not be benched.
    Returns True only when the probe itself fails to reach the host.
    """
    try:
        # Keep the probe cheap: 5-second overall budget.
        probe_timeout = aiohttp.ClientTimeout(total=5)
        async with aiohttp.ClientSession(timeout=probe_timeout) as session:
            from urllib.parse import urlparse

            parts = urlparse(provider.http_url)
            probe_url = f"{parts.scheme}://{parts.netloc}"
            async with session.get(probe_url) as response:
                # Got a response of any kind -> host is reachable.
                return False
    except Exception as probe_error:
        # Probe could not reach the host: treat the provider as down.
        self.logger.debug(f"Health check failed for {provider.name}: {probe_error}")
        return True
async def _make_request(self, provider: Provider, request: Dict[str, Any]) -> Dict[str, Any]:
transformed_request = provider.transform_request(request)

View File

@ -24,17 +24,24 @@ class WebSocketProxy:
return ws
try:
provider_ws = await self._connect_to_provider(provider)
if not provider_ws:
provider_connection = await self._connect_to_provider(provider)
if not provider_connection:
await ws.close(code=1011, message=b'Failed to connect to provider')
return ws
provider_ws, provider_session = provider_connection
await asyncio.gather(
self._proxy_client_to_provider(ws, provider_ws, provider),
self._proxy_provider_to_client(provider_ws, ws, provider),
return_exceptions=True
)
# Clean up provider connection
if not provider_ws.closed:
await provider_ws.close()
await provider_session.close()
except Exception as e:
self.logger.error(f"WebSocket proxy error: {e}")
@ -44,14 +51,18 @@ class WebSocketProxy:
return ws
async def _connect_to_provider(self, provider) -> Optional[object]:
async def _connect_to_provider(self, provider) -> Optional[tuple]:
session = None
try:
session = ClientSession()
self.logger.info(f"Attempting to connect to provider {provider.name} at {provider.ws_url}")
ws = await session.ws_connect(provider.ws_url)
self.logger.info(f"Connected to provider {provider.name} WebSocket")
return ws
self.logger.info(f"Successfully connected to provider {provider.name} WebSocket at {provider.ws_url}")
return (ws, session)
except Exception as e:
self.logger.error(f"Failed to connect to provider {provider.name}: {e}")
self.logger.error(f"Failed to connect to provider {provider.name} at {provider.ws_url}: {e}")
if session:
await session.close()
return None
async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
@ -59,11 +70,14 @@ class WebSocketProxy:
if msg.type == WSMsgType.TEXT:
try:
data = json.loads(msg.data)
method = data.get('method', 'unknown')
self.logger.debug(f"Received from client: {data}")
transformed_request = provider.transform_request(data)
await provider_ws.send_str(json.dumps(transformed_request))
self.logger.debug(f"Forwarded message to {provider.name}: {data.get('method', 'unknown')}")
self.logger.debug(f"Forwarded message to {provider.name}: {method}")
except json.JSONDecodeError:
self.logger.warning("Received invalid JSON from client")
@ -83,6 +97,7 @@ class WebSocketProxy:
if msg.type == WSMsgType.TEXT:
try:
data = json.loads(msg.data)
self.logger.debug(f"Received from provider {provider.name}: {data}")
transformed_response = provider.transform_response(data)
@ -90,31 +105,31 @@ class WebSocketProxy:
subscription_id = transformed_response.get("result")
if subscription_id:
self.subscription_mappings[str(subscription_id)] = provider.name
self.logger.debug(f"SIGNATURE_SUBSCRIBE: Mapped subscription {subscription_id} to {provider.name}")
transformed_response["_cached"] = False
transformed_response["_provider"] = provider.name
method = transformed_response.get("method", "")
params = transformed_response.get("params", {})
if method and params:
self.router.cache.set(method, params, transformed_response)
await client_ws.send_str(json.dumps(transformed_response))
self.logger.debug(f"Forwarded response from {provider.name}")
self.logger.debug(f"Forwarded response to client from {provider.name}: {transformed_response}")
except json.JSONDecodeError:
self.logger.warning(f"Received invalid JSON from provider {provider.name}")
except Exception as e:
self.logger.error(f"Error forwarding from provider: {e}")
break
# Don't break here - continue processing other messages
continue
elif msg.type == WSMsgType.ERROR:
self.logger.error(f'Provider WebSocket error: {provider_ws.exception()}')
break
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
self.logger.warning(f"Provider WebSocket connection closed from {provider.name}")
break
self.logger.warning(f"Provider-to-client message loop ended for {provider.name}")
async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
router: Router = request.app['router']