diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..845f69b
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,14 @@
+# Provider endpoints and auth
+ALCHEMY_API_KEY=your_key_here
+HELIUS_API_KEY=your_key_here
+QUICKNODE_ENDPOINT=your_endpoint.quiknode.pro
+QUICKNODE_TOKEN=your_token_here
+
+# Proxy settings
+PROXY_PORT=8545
+CACHE_SIZE_GB=100
+BACKOFF_MINUTES=30
+
+# Logging
+LOG_LEVEL=INFO
+ERROR_DB_PATH=./errors.db
\ No newline at end of file
diff --git a/cache.py b/cache.py
new file mode 100644
index 0000000..e840864
--- /dev/null
+++ b/cache.py
@@ -0,0 +1,36 @@
+import json
+import os
+from typing import Dict, Any, Optional
+import diskcache
+
+
+class Cache:
+    def __init__(self, cache_dir: str = "./cache", size_limit_gb: int = 100):
+        self.cache_dir = cache_dir
+        self.size_limit_bytes = size_limit_gb * 1024 * 1024 * 1024
+        self.cache = diskcache.Cache(
+            directory=cache_dir,
+            size_limit=self.size_limit_bytes,
+            eviction_policy='least-recently-used'
+        )
+
+    def _make_key(self, method: str, params: Dict[str, Any]) -> str:
+        return f"{method}:{json.dumps(params, sort_keys=True)}"
+
+    def get(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        key = self._make_key(method, params)
+        return self.cache.get(key)
+
+    def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any]) -> None:
+        key = self._make_key(method, params)
+        self.cache.set(key, response)
+
+    def size_check(self) -> Dict[str, Any]:
+        volume = self.cache.volume()  # stats() returns (hits, misses), not size/count
+        return {
+            "size_bytes": volume,
+            "size_gb": volume / (1024 * 1024 * 1024),
+            "count": len(self.cache),
+            "limit_gb": self.size_limit_bytes / (1024 * 1024 * 1024)
+        }
+
diff --git a/errors.py b/errors.py
new file mode 100644
index 0000000..cf0d582
--- /dev/null
+++ b/errors.py
@@ -0,0 +1,74 @@
+import sqlite3
+import json
+import uuid
+import traceback
+from datetime import datetime
+from typing import Dict, Any, Optional
+
+
+class ErrorLogger:
+    def __init__(self, db_path: str = "./errors.db"):
+        self.db_path = db_path
+        self.setup_db()
+
+    def setup_db(self) -> None:
+        with sqlite3.connect(self.db_path) as conn:
+            conn.execute("""
+                CREATE TABLE IF NOT EXISTS errors (
+                    id TEXT PRIMARY KEY,
+                    timestamp DATETIME,
+                    provider TEXT,
+                    request_method TEXT,
+                    request_params TEXT,
+                    error_type TEXT,
+                    error_message TEXT,
+                    error_traceback TEXT
+                )
+            """)
+            conn.commit()
+
+    def log_error(self, provider: str, request: Dict[str, Any], error: Exception) -> str:
+        error_id = str(uuid.uuid4())
+        timestamp = datetime.now()
+
+        with sqlite3.connect(self.db_path) as conn:
+            conn.execute("""
+                INSERT INTO errors (
+                    id, timestamp, provider, request_method, request_params,
+                    error_type, error_message, error_traceback
+                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+            """, (
+                error_id,
+                timestamp,
+                provider,
+                request.get("method", "unknown"),
+                json.dumps(request.get("params", {})),
+                type(error).__name__,
+                str(error),
+                traceback.format_exc()
+            ))
+            conn.commit()
+
+        return error_id
+
+    def get_error(self, error_id: str) -> Optional[Dict[str, Any]]:
+        with sqlite3.connect(self.db_path) as conn:
+            conn.row_factory = sqlite3.Row
+            cursor = conn.execute(
+                "SELECT * FROM errors WHERE id = ?", (error_id,)
+            )
+            row = cursor.fetchone()
+
+            if row:
+                return {
+                    "id": row["id"],
+                    "timestamp": row["timestamp"],
+                    "provider": row["provider"],
+                    "request_method": row["request_method"],
+                    "request_params": json.loads(row["request_params"]),
+                    "error_type": row["error_type"],
+                    "error_message": row["error_message"],
+                    "error_traceback": row["error_traceback"]
+                }
+            return None
+
diff --git a/http_proxy.py b/http_proxy.py
new file mode 100644
index 0000000..fafac6f
--- /dev/null
+++ b/http_proxy.py
@@ -0,0 +1,73 @@
+import json
+import logging
+from aiohttp import web, ClientSession
+from router import Router
+
+
+async def handle_rpc_request(request: web.Request) -> web.Response:
+    router: Router = request.app['router']
+    logger = logging.getLogger(__name__)
+
+    try:
+        body = await request.json()
+
+        if not isinstance(body, dict):
+            return web.json_response({
+                "jsonrpc": "2.0",
+                "id": body.get("id", 1) if isinstance(body, dict) else 1,
+                "error": {
+                    "code": -32600,
+                    "message": "Invalid Request"
+                }
+            }, status=400)
+
+        method = body.get("method")
+        params = body.get("params", [])
+        request_id = body.get("id", 1)
+
+        if not method:
+            return web.json_response({
+                "jsonrpc": "2.0",
+                "id": request_id,
+                "error": {
+                    "code": -32600,
+                    "message": "Missing method"
+                }
+            }, status=400)
+
+        logger.info(f"Handling RPC request: {method}")
+
+        response = await router.route_request(method, params)
+        response["id"] = request_id
+
+        return web.json_response(response)
+
+    except json.JSONDecodeError:
+        return web.json_response({
+            "jsonrpc": "2.0",
+            "id": 1,
+            "error": {
+                "code": -32700,
+                "message": "Parse error"
+            }
+        }, status=400)
+
+    except Exception as e:
+        logger.error(f"Unexpected error: {e}")
+        return web.json_response({
+            "jsonrpc": "2.0",
+            "id": 1,
+            "error": {
+                "code": -32603,
+                "message": "Internal error"
+            }
+        }, status=500)
+
+
+
+
+
+def setup_routes(app: web.Application) -> None:
+    app.router.add_post('/', handle_rpc_request)
+
+
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..63678c2
--- /dev/null
+++ b/main.py
@@ -0,0 +1,70 @@
+import os
+import logging
+from dotenv import load_dotenv
+from aiohttp import web
+from providers import create_providers
+from cache import Cache
+from errors import ErrorLogger
+from router import Router
+from http_proxy import setup_routes
+from ws_proxy import setup_ws_routes
+
+
+def load_config() -> dict:
+    load_dotenv()
+
+    return {
+        "proxy_port": int(os.getenv("PROXY_PORT", 8545)),
+        "cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 100)),
+        "backoff_minutes": int(os.getenv("BACKOFF_MINUTES", 30)),
+        "log_level": os.getenv("LOG_LEVEL", "INFO"),
+        "error_db_path": os.getenv("ERROR_DB_PATH", "./errors.db"),
+    }
+
+
+def setup_logging(log_level: str) -> None:
+    logging.basicConfig(
+        level=getattr(logging, log_level.upper()),
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    )
+
+
+def create_app(config: dict) -> web.Application:
+    app = web.Application()
+
+    providers = create_providers()
+    cache = Cache(size_limit_gb=config["cache_size_gb"])
+    error_logger = ErrorLogger(db_path=config["error_db_path"])
+    router = Router(providers, cache, error_logger, backoff_minutes=config["backoff_minutes"])
+
+    app['router'] = router
+    app['config'] = config
+
+    setup_routes(app)
+    setup_ws_routes(app)
+
+    return app
+
+
+
+
+def main() -> None:
+    config = load_config()
+    setup_logging(config["log_level"])
+
+    logger = logging.getLogger(__name__)
+    logger.info(f"Starting Solana RPC Proxy on port {config['proxy_port']}")
+    logger.info(f"Cache size limit: {config['cache_size_gb']}GB")
+    logger.info(f"Provider backoff time: {config['backoff_minutes']} minutes")
+
+    app = create_app(config)
+
+    web.run_app(
+        app,
+        host='0.0.0.0',
+        port=config["proxy_port"]
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/normalizer.py b/normalizer.py
new file mode 100644
index 0000000..dfb94c0
--- /dev/null
+++ b/normalizer.py
@@ -0,0 +1,28 @@
+from typing import Dict, Any
+
+
+def normalize_response(provider: str, response: Dict[str, Any]) -> Dict[str, Any]:
+    normalized = response.copy()
+
+    # Ensure consistent field names
+    if "result" in normalized and normalized["result"] is None:
+        # Some providers return null, others omit the field
+        pass
+
+    # Handle null vs missing fields consistently
+    if "error" in normalized and normalized["error"] is None:
+        del normalized["error"]
+
+    return normalized
+
+
+def normalize_error(error: Exception, error_id: str) -> Dict[str, Any]:
+    return {
+        "jsonrpc": "2.0",
+        "id": 1,
+        "error": {
+            "code": -32603,
+            "message": str(error),
+            "data": {"error_id": error_id}
+        }
+    }
\ No newline at end of file
diff --git a/providers.py b/providers.py
new file mode 100644
index 0000000..7da9c66
--- /dev/null
+++ b/providers.py
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
+from typing import Dict, Any, List, Optional
+import os
+
+
+class Provider(ABC):
+    def __init__(self, name: str):
+        self.name = name
+        self.backoff_until: Optional[datetime] = None
+
+    @property
+    @abstractmethod
+    def http_url(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def ws_url(self) -> str:
+        pass
+
+    def transform_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
+        return request
+
+    def transform_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
+        return response
+
+    def is_available(self) -> bool:
+        if self.backoff_until is None:
+            return True
+        return datetime.now() > self.backoff_until
+
+    def mark_failed(self, backoff_minutes: int = 30) -> None:
+        self.backoff_until = datetime.now() + timedelta(minutes=backoff_minutes)
+
+
+class AlchemyProvider(Provider):
+    def __init__(self):
+        super().__init__("alchemy")
+        self.api_key = os.getenv("ALCHEMY_API_KEY", "")
+
+    @property
+    def http_url(self) -> str:
+        return f"https://solana-mainnet.g.alchemy.com/v2/{self.api_key}"
+
+    @property
+    def ws_url(self) -> str:
+        return f"wss://solana-mainnet.g.alchemy.com/v2/{self.api_key}"
+
+
+class PublicNodeProvider(Provider):
+    def __init__(self):
+        super().__init__("publicnode")
+
+    @property
+    def http_url(self) -> str:
+        return "https://solana-rpc.publicnode.com"
+
+    @property
+    def ws_url(self) -> str:
+        return "wss://solana-rpc.publicnode.com"
+
+
+class HeliusProvider(Provider):
+    def __init__(self):
+        super().__init__("helius")
+        self.api_key = os.getenv("HELIUS_API_KEY", "")
+
+    @property
+    def http_url(self) -> str:
+        return f"https://mainnet.helius-rpc.com/?api-key={self.api_key}"
+
+    @property
+    def ws_url(self) -> str:
+        return f"wss://mainnet.helius-rpc.com/?api-key={self.api_key}"
+
+
+class QuickNodeProvider(Provider):
+    def __init__(self):
+        super().__init__("quicknode")
+        self.endpoint = os.getenv("QUICKNODE_ENDPOINT", "")
+        self.token = os.getenv("QUICKNODE_TOKEN", "")
+
+    @property
+    def http_url(self) -> str:
+        return f"https://{self.endpoint}/{self.token}/"
+
+    @property
+    def ws_url(self) -> str:
+        return f"wss://{self.endpoint}/{self.token}/"
+
+
+class SolanaPublicProvider(Provider):
+    def __init__(self):
+        super().__init__("solana_public")
+
+    @property
+    def http_url(self) -> str:
+        return "https://api.mainnet-beta.solana.com"
+
+    @property
+    def ws_url(self) -> str:
+        return "wss://api.mainnet-beta.solana.com"
+
+
+def create_providers() -> List[Provider]:
+    return [
+        AlchemyProvider(),
+        PublicNodeProvider(),
+        HeliusProvider(),
+        QuickNodeProvider(),
+        SolanaPublicProvider()
+    ]
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..dcb87f8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,25 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "solana-proxy"
+version = "0.1.0"
+description = "A Python-based reverse proxy for Solana RPC endpoints with automatic failover and caching"
+readme = "README.md"
+requires-python = ">=3.8"
+dependencies = [
+    "aiohttp==3.9.0",
+    "python-dotenv==1.0.0",
+    "diskcache==5.6.0",
+]
+
+[project.optional-dependencies]
+test = [
+    "pytest>=7.0.0",
+    "pytest-asyncio>=0.21.0",
+    "aiohttp-client-manager>=1.1.0",
+]
+
+[project.scripts]
+solana-proxy = "main:main"
\ No newline at end of file
diff --git a/router.py b/router.py
new file mode 100644
index 0000000..ee5b535
--- /dev/null
+++ b/router.py
@@ -0,0 +1,106 @@
+import aiohttp
+import json
+import logging
+from typing import Dict, Any, Optional, List
+from providers import Provider
+from cache import Cache
+from errors import ErrorLogger
+
+
+class Router:
+    def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger, backoff_minutes: int = 30):
+        self.providers = providers
+        self.cache = cache
+        self.error_logger = error_logger
+        self.current_provider_index = 0
+        self.backoff_minutes = backoff_minutes
+        self.logger = logging.getLogger(__name__)
+
+    async def route_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
+        request = {"method": method, "params": params}
+
+        cached_response = self.cache.get(method, params)
+        if cached_response is not None:
+            self.logger.debug(f"Cache hit for {method}")
+            cached_response["_cached"] = True
+            cached_response["_provider"] = "cache"
+            return cached_response
+
+        for attempt in range(len(self.providers)):
+            provider = self.get_next_available_provider()
+            if not provider:
+                return self._create_error_response(
+                    "All providers are currently unavailable",
+                    "NO_AVAILABLE_PROVIDERS"
+                )
+
+            try:
+                response = await self._make_request(provider, request)
+
+                transformed_response = provider.transform_response(response)
+                transformed_response["_cached"] = False
+                transformed_response["_provider"] = provider.name
+
+                self.cache.set(method, params, transformed_response)
+                self.logger.info(f"Request succeeded via {provider.name}")
+                return transformed_response
+
+            except Exception as error:
+                error_id = self.error_logger.log_error(provider.name, request, error)
+                self.logger.warning(f"Provider {provider.name} failed: {error} (ID: {error_id})")
+                provider.mark_failed(self.backoff_minutes)
+
+        return self._create_error_response(
+            "All providers failed to handle the request",
+            "ALL_PROVIDERS_FAILED"
+        )
+
+    def get_next_available_provider(self) -> Optional[Provider]:
+        for _ in range(len(self.providers)):
+            provider = self.providers[self.current_provider_index]
+            self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
+
+            if provider.is_available():
+                return provider
+
+        return None
+
+    async def _make_request(self, provider: Provider, request: Dict[str, Any]) -> Dict[str, Any]:
+        transformed_request = provider.transform_request(request)
+
+        rpc_request = {
+            "jsonrpc": "2.0",
+            "id": 1,
+            "method": transformed_request["method"],
+            "params": transformed_request["params"]
+        }
+
+        timeout = aiohttp.ClientTimeout(total=30)
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.post(
+                provider.http_url,
+                json=rpc_request,
+                headers={"Content-Type": "application/json"}
+            ) as response:
+                if response.status != 200:
+                    raise Exception(f"HTTP {response.status}: {await response.text()}")
+
+                result = await response.json()
+
+                if "error" in result:
+                    raise Exception(f"RPC Error: {result['error']}")
+
+                return result
+
+    def _create_error_response(self, message: str, code: str) -> Dict[str, Any]:
+        return {
+            "jsonrpc": "2.0",
+            "id": 1,
+            "error": {
+                "code": -32000,
+                "message": message,
+                "data": {"proxy_error_code": code}
+            },
+            "_cached": False,
+            "_provider": "proxy_error"
+        }
\ No newline at end of file
diff --git a/test_e2e.py b/test_e2e.py
new file mode 100644
index 0000000..9616d16
--- /dev/null
+++ b/test_e2e.py
@@ -0,0 +1,180 @@
+import asyncio
+import json
+import aiohttp
+import pytest
+from main import create_app, load_config
+
+
+@pytest.fixture
+async def app():
+    config = load_config()
+    config["proxy_port"] = 8546  # Use different port for testing
+    app = create_app(config)
+    return app
+
+
+@pytest.fixture
+async def client(app, aiohttp_client):
+    return await aiohttp_client(app)
+
+
+async def test_http_proxy_getBalance(client):
+    """Test basic HTTP RPC request"""
+    request_data = {
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "getBalance",
+        "params": ["11111111111111111111111111111112"]  # System program ID
+    }
+
+    async with client.post('/', json=request_data) as response:
+        assert response.status == 200
+        data = await response.json()
+
+        assert data["jsonrpc"] == "2.0"
+        assert data["id"] == 1
+        assert "result" in data or "error" in data
+        assert "_provider" in data
+        assert data["_provider"] != "proxy_error"
+
+
+async def test_cache_functionality(client):
+    """Test that responses are cached"""
+    request_data = {
+        "jsonrpc": "2.0",
+        "id": 1,
+        "method": "getBalance",
+        "params": ["11111111111111111111111111111112"]
+    }
+
+    # First request
+    async with client.post('/', json=request_data) as response:
+        data1 = await response.json()
+        provider1 = data1["_provider"]
+        cached1 = data1["_cached"]
+
+    # Second request (should be cached)
+    async with client.post('/', json=request_data) as response:
+        data2 = await response.json()
+        provider2 = data2["_provider"]
+        cached2 = data2["_cached"]
+
+    # First request shouldn't be cached, second should be
+    assert not cached1
+    assert cached2
+    assert provider2 == "cache"
+
+
+
+async def test_invalid_json_request(client):
+    """Test invalid JSON handling"""
+    async with client.post('/', data="invalid json") as response:
+        assert response.status == 400
+        data = await response.json()
+
+        assert "error" in data
+        assert data["error"]["code"] == -32700
+
+
+async def test_missing_method(client):
+    """Test missing method handling"""
+    request_data = {
+        "jsonrpc": "2.0",
+        "id": 1,
+        "params": []
+    }
+
+    async with client.post('/', json=request_data) as response:
+        assert response.status == 400
+        data = await response.json()
+
+        assert "error" in data
+        assert data["error"]["code"] == -32600
+
+
+async def test_websocket_connection(client):
+    """Test WebSocket connection establishment"""
+    try:
+        async with client.ws_connect('/ws') as ws:
+            # Send a simple subscription request
+            subscribe_request = {
+                "jsonrpc": "2.0",
+                "id": 1,
+                "method": "accountSubscribe",
+                "params": [
+                    "11111111111111111111111111111112",
+                    {"encoding": "jsonParsed"}
+                ]
+            }
+
+            await ws.send_str(json.dumps(subscribe_request))
+
+            # Wait for response (with timeout)
+            try:
+                msg = await asyncio.wait_for(ws.receive(), timeout=10.0)
+                if msg.type == aiohttp.WSMsgType.TEXT:
+                    response = json.loads(msg.data)
+                    assert "result" in response or "error" in response
+                    if "_provider" in response:
+                        assert response["_provider"] != "proxy_error"
+
+            except asyncio.TimeoutError:
+                # WebSocket might timeout if providers are unavailable
+                # This is acceptable for the test
+                pass
+
+    except Exception as e:
+        # WebSocket connection might fail if no providers are available
+        # This is acceptable for testing environment
+        pytest.skip(f"WebSocket test skipped due to provider unavailability: {e}")
+
+
+if __name__ == "__main__":
+    # Run tests manually if executed directly
+    import sys
+
+    async def run_tests():
+        config = load_config()
+        config["proxy_port"] = 8546
+        app = create_app(config)
+
+        # Start the application
+        runner = aiohttp.web.AppRunner(app)
+        await runner.setup()
+        site = aiohttp.web.TCPSite(runner, 'localhost', config["proxy_port"])
+        await site.start()
+
+        print(f"Test server started on port {config['proxy_port']}")
+
+        try:
+            # Simple manual test
+            async with aiohttp.ClientSession() as session:
+                request_data = {
+                    "jsonrpc": "2.0",
+                    "id": 1,
+                    "method": "getBalance",
+                    "params": ["11111111111111111111111111111112"]
+                }
+
+                async with session.post(
+                    f"http://localhost:{config['proxy_port']}/",
+                    json=request_data
+                ) as response:
+                    data = await response.json()
+                    print("Response:", json.dumps(data, indent=2))
+
+                # Test second request (should be cached)
+                async with session.post(
+                    f"http://localhost:{config['proxy_port']}/",
+                    json=request_data
+                ) as response:
+                    cached_data = await response.json()
+                    print("Cached response:", json.dumps(cached_data, indent=2))
+
+        except Exception as e:
+            print(f"Test error: {e}")
+
+        finally:
+            await runner.cleanup()
+
+    asyncio.run(run_tests())
\ No newline at end of file
diff --git a/ws_proxy.py b/ws_proxy.py
new file mode 100644
index 0000000..76536eb
--- /dev/null
+++ b/ws_proxy.py
@@ -0,0 +1,139 @@
+import json
+import asyncio
+import logging
+from typing import Dict, Any, Optional
+from aiohttp import web, WSMsgType, ClientSession
+from router import Router
+
+
+class WebSocketProxy:
+    def __init__(self, router: Router):
+        self.router = router
+        self.logger = logging.getLogger(__name__)
+        self.subscription_mappings: Dict[str, str] = {}
+
+    async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+
+        self.logger.info("New WebSocket connection established")
+
+        provider = self.router.get_next_available_provider()
+        if not provider:
+            await ws.close(code=1011, message=b'No available providers')
+            return ws
+
+        session = None
+        provider_ws = None
+        try:
+            connection = await self._connect_to_provider(provider)
+            if connection is None:
+                await ws.close(code=1011, message=b'Failed to connect to provider')
+                return ws
+            session, provider_ws = connection
+
+            await asyncio.gather(
+                self._proxy_client_to_provider(ws, provider_ws, provider),
+                self._proxy_provider_to_client(provider_ws, ws, provider),
+                return_exceptions=True
+            )
+
+        except Exception as e:
+            self.logger.error(f"WebSocket proxy error: {e}")
+
+        finally:
+            # Close the upstream socket and its session so connections are not leaked
+            if provider_ws is not None and not provider_ws.closed:
+                await provider_ws.close()
+            if session is not None:
+                await session.close()
+            if not ws.closed:
+                await ws.close()
+
+        return ws
+
+    async def _connect_to_provider(self, provider) -> Optional[tuple]:
+        # Returns (session, ws) on success; the caller owns and must close both.
+        try:
+            session = ClientSession()
+            try:
+                ws = await session.ws_connect(provider.ws_url)
+            except Exception:
+                await session.close()
+                raise
+            self.logger.info(f"Connected to provider {provider.name} WebSocket")
+            return session, ws
+        except Exception as e:
+            self.logger.error(f"Failed to connect to provider {provider.name}: {e}")
+            return None
+
+    async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
+        async for msg in client_ws:
+            if msg.type == WSMsgType.TEXT:
+                try:
+                    data = json.loads(msg.data)
+
+                    transformed_request = provider.transform_request(data)
+
+                    await provider_ws.send_str(json.dumps(transformed_request))
+                    self.logger.debug(f"Forwarded message to {provider.name}: {data.get('method', 'unknown')}")
+
+                except json.JSONDecodeError:
+                    self.logger.warning("Received invalid JSON from client")
+                except Exception as e:
+                    self.logger.error(f"Error forwarding to provider: {e}")
+                    break
+
+            elif msg.type == WSMsgType.ERROR:
+                self.logger.error(f'WebSocket error: {client_ws.exception()}')
+                break
+
+            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
+                break
+
+    async def _proxy_provider_to_client(self, provider_ws, client_ws: web.WebSocketResponse, provider) -> None:
+        async for msg in provider_ws:
+            if msg.type == WSMsgType.TEXT:
+                try:
+                    data = json.loads(msg.data)
+
+                    transformed_response = provider.transform_response(data)
+
+                    if "result" in transformed_response and "subscription" in str(transformed_response.get("result", {})):
+                        subscription_id = transformed_response.get("result")
+                        if subscription_id:
+                            self.subscription_mappings[str(subscription_id)] = provider.name
+
+                    transformed_response["_cached"] = False
+                    transformed_response["_provider"] = provider.name
+
+                    method = transformed_response.get("method", "")
+                    params = transformed_response.get("params", {})
+                    if method and params:
+                        self.router.cache.set(method, params, transformed_response)
+
+                    await client_ws.send_str(json.dumps(transformed_response))
+                    self.logger.debug(f"Forwarded response from {provider.name}")
+
+                except json.JSONDecodeError:
+                    self.logger.warning(f"Received invalid JSON from provider {provider.name}")
+                except Exception as e:
+                    self.logger.error(f"Error forwarding from provider: {e}")
+                    break
+
+            elif msg.type == WSMsgType.ERROR:
+                self.logger.error(f'Provider WebSocket error: {provider_ws.exception()}')
+                break
+
+            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
+                break
+
+
+async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
+    router: Router = request.app['router']
+    ws_proxy = WebSocketProxy(router)
+    return await ws_proxy.handle_ws_connection(request)
+
+
+def setup_ws_routes(app: web.Application) -> None:
+    app.router.add_get('/ws', handle_ws_connection)
\ No newline at end of file