solana-proxy/ws_proxy.py
shreerang 75eaba600a Handle CORS and selectively cache responses for appropriate methods (#1)
Part of https://www.notion.so/Laconic-Mainnet-Plan-1eca6b22d47280569cd0d1e6d711d949

Co-authored-by: Shreerang Kale <shreerangkale@gmail.com>
Reviewed-on: #1
Co-authored-by: shreerang <shreerang@noreply.git.vdb.to>
Co-committed-by: shreerang <shreerang@noreply.git.vdb.to>
2025-08-01 10:37:06 +00:00

142 lines
5.7 KiB
Python

import json
import asyncio
import logging
from typing import Dict, Any, Optional
from aiohttp import web, WSMsgType, ClientSession
from router import Router
class WebSocketProxy:
    """Relay one client WebSocket connection to one upstream provider WebSocket.

    The provider is chosen with ``router.get_next_available_provider()``.
    JSON text frames are forwarded in both directions through
    ``provider.transform_request`` / ``provider.transform_response``.
    The relay stops as soon as either side stops; both sockets and the
    provider's ``ClientSession`` are then closed.
    """

    def __init__(self, router: Router):
        self.router = router
        self.logger = logging.getLogger(__name__)
        # Subscription id (stringified) -> provider name, filled in by
        # _proxy_provider_to_client. Not read anywhere in this class.
        self.subscription_mappings: Dict[str, str] = {}

    async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
        """Accept a client WebSocket and proxy it to the next available provider.

        The client socket is closed with code 1011 when no provider is
        available or the upstream connection cannot be opened.

        Args:
            request: the incoming aiohttp request to upgrade.

        Returns:
            The client ``WebSocketResponse`` (closed by the time this returns).
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.logger.info("New WebSocket connection established")

        provider = self.router.get_next_available_provider()
        if not provider:
            await ws.close(code=1011, message=b'No available providers')
            return ws

        provider_ws = None
        provider_session = None
        try:
            provider_connection = await self._connect_to_provider(provider)
            if not provider_connection:
                await ws.close(code=1011, message=b'Failed to connect to provider')
                return ws
            provider_ws, provider_session = provider_connection
            await self._relay_until_either_side_stops(ws, provider_ws, provider)
        except Exception as e:
            self.logger.exception("WebSocket proxy error: %s", e)
        finally:
            # Release the upstream connection on every exit path, including
            # errors and task cancellation, not just the normal one.
            if provider_ws is not None and not provider_ws.closed:
                await provider_ws.close()
            if provider_session is not None:
                await provider_session.close()
            if not ws.closed:
                await ws.close()
        return ws

    async def _relay_until_either_side_stops(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Run both forwarding loops and return as soon as either one finishes.

        The other loop is cancelled. Otherwise a client that disconnects would
        leave the provider loop (and the upstream socket) running until the
        provider closed on its own, and vice versa.
        """
        pumps = [
            asyncio.create_task(self._proxy_client_to_provider(client_ws, provider_ws, provider)),
            asyncio.create_task(self._proxy_provider_to_client(provider_ws, client_ws, provider)),
        ]
        try:
            await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for pump in pumps:
                pump.cancel()  # no-op for a task that already finished
            # Wait for cancelled pumps to unwind and collect their outcomes.
            outcomes = await asyncio.gather(*pumps, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception) and not isinstance(outcome, asyncio.CancelledError):
                self.logger.error("WebSocket proxy error: %s", outcome)

    async def _connect_to_provider(self, provider) -> Optional[tuple]:
        """Open a WebSocket to ``provider.ws_url``.

        Returns:
            ``(ws, session)`` on success; the caller owns and must close both.
            ``None`` if the connection failed (the error is logged).
        """
        session: Optional[ClientSession] = None
        connected = False
        try:
            session = ClientSession()
            self.logger.info("Attempting to connect to provider %s at %s", provider.name, provider.ws_url)
            ws = await session.ws_connect(provider.ws_url)
            self.logger.info("Successfully connected to provider %s WebSocket at %s", provider.name, provider.ws_url)
            connected = True
            return (ws, session)
        except Exception as e:
            self.logger.error("Failed to connect to provider %s at %s: %s", provider.name, provider.ws_url, e)
            return None
        finally:
            # Also covers cancellation during ws_connect, which
            # `except Exception` does not catch.
            if not connected and session is not None:
                await session.close()

    async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Forward client text frames to the provider until the client stops.

        Invalid JSON is logged and skipped. Any other forwarding error ends the loop.
        """
        async for msg in client_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    # JSON arrays (batches) and scalars have no top-level "method".
                    method = data.get('method', 'unknown') if isinstance(data, dict) else 'unknown'
                    self.logger.debug("Received from client: %s", data)
                    transformed_request = provider.transform_request(data)
                    await provider_ws.send_str(json.dumps(transformed_request))
                    self.logger.debug("Forwarded message to %s: %s", provider.name, method)
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from client")
                except Exception as e:
                    self.logger.error("Error forwarding to provider: %s", e)
                    break
            elif msg.type == WSMsgType.ERROR:
                self.logger.error('WebSocket error: %s', client_ws.exception())
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                # Defensive: aiohttp's async iterator normally stops on close
                # frames by itself.
                break

    async def _proxy_provider_to_client(self, provider_ws, client_ws: web.WebSocketResponse, provider) -> None:
        """Forward provider text frames to the client until the provider stops.

        JSON object responses are tagged with ``_cached = False`` and
        ``_provider = provider.name`` before forwarding. Other JSON values
        (e.g. arrays) are forwarded untagged. Per-message errors are logged,
        and the loop moves on to the next message.
        """
        async for msg in provider_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    self.logger.debug("Received from provider %s: %s", provider.name, data)
                    transformed_response = provider.transform_response(data)
                    if isinstance(transformed_response, dict):
                        # NOTE(review): heuristic kept unchanged. It only matches
                        # when the stringified "result" contains the word
                        # "subscription". A bare numeric subscription id would
                        # not match; confirm against the provider's actual
                        # subscribe-response shape.
                        if "result" in transformed_response and "subscription" in str(transformed_response.get("result", {})):
                            subscription_id = transformed_response.get("result")
                            if subscription_id:
                                self.subscription_mappings[str(subscription_id)] = provider.name
                                self.logger.debug("SIGNATURE_SUBSCRIBE: Mapped subscription %s to %s", subscription_id, provider.name)
                        transformed_response["_cached"] = False
                        transformed_response["_provider"] = provider.name
                    await client_ws.send_str(json.dumps(transformed_response))
                    self.logger.debug("Forwarded response to client from %s: %s", provider.name, transformed_response)
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from provider %s", provider.name)
                except Exception as e:
                    # Keep relaying later messages. If the client has gone away,
                    # the relay task cancels this loop once the client loop ends.
                    self.logger.error("Error forwarding from provider: %s", e)
            elif msg.type == WSMsgType.ERROR:
                self.logger.error('Provider WebSocket error: %s', provider_ws.exception())
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                # Defensive: aiohttp's async iterator normally stops on close
                # frames by itself.
                self.logger.warning("Provider WebSocket connection closed from %s", provider.name)
                break
        self.logger.warning("Provider-to-client message loop ended for %s", provider.name)
async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
    """aiohttp handler: proxy this WebSocket request through a new WebSocketProxy.

    The router is looked up under ``request.app['router']``. A fresh proxy
    instance is built for every incoming connection.
    """
    proxy = WebSocketProxy(request.app['router'])
    return await proxy.handle_ws_connection(request)
def setup_ws_routes(app: web.Application) -> None:
    """Register the WebSocket proxy endpoint at ``/ws`` on *app*."""
    app.add_routes([web.get('/ws', handle_ws_connection)])