Part of https://www.notion.so/Laconic-Mainnet-Plan-1eca6b22d47280569cd0d1e6d711d949 Co-authored-by: Shreerang Kale <shreerangkale@gmail.com> Reviewed-on: #1 Co-authored-by: shreerang <shreerang@noreply.git.vdb.to> Co-committed-by: shreerang <shreerang@noreply.git.vdb.to>
142 lines
5.7 KiB
Python
142 lines
5.7 KiB
Python
import json
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, Any, Optional
|
|
from aiohttp import web, WSMsgType, ClientSession
|
|
from router import Router
|
|
|
|
|
|
class WebSocketProxy:
    """Relay a client WebSocket to a single upstream provider WebSocket.

    For each incoming connection a provider is chosen with
    ``router.get_next_available_provider()``. Text frames are then relayed in
    both directions: requests pass through ``provider.transform_request`` and
    responses through ``provider.transform_response``.
    """

    def __init__(self, router: Router):
        self.router = router
        self.logger = logging.getLogger(__name__)
        # Subscription id (as str) -> name of the provider that returned it.
        # Only written inside this class; NOTE(review): confirm who reads it.
        self.subscription_mappings: Dict[str, str] = {}

    async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
        """Upgrade *request* to a WebSocket and proxy it to a provider.

        The client is closed with code 1011 when no provider is available or
        the upstream connection cannot be opened. Otherwise frames are relayed
        until either side stops. The upstream socket, its ``ClientSession``
        and the client socket are released in ``finally``, so an error or a
        cancelled handler does not leak them.

        Returns:
            The client ``WebSocketResponse`` (closed by the time it returns).
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        self.logger.info("New WebSocket connection established")

        provider = self.router.get_next_available_provider()
        if not provider:
            await ws.close(code=1011, message=b'No available providers')
            return ws

        provider_ws = None
        provider_session = None
        try:
            provider_connection = await self._connect_to_provider(provider)
            if not provider_connection:
                await ws.close(code=1011, message=b'Failed to connect to provider')
                return ws

            provider_ws, provider_session = provider_connection
            await self._relay_until_first_side_ends(ws, provider_ws, provider)

        except Exception as e:
            self.logger.exception("WebSocket proxy error: %s", e)

        finally:
            # Cleanup lives here (not only on the success path) so that
            # exceptions and task cancellation also release upstream resources.
            await self._close_provider_connection(provider_ws, provider_session)
            if not ws.closed:
                await ws.close()

        return ws

    async def _relay_until_first_side_ends(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Run both relay directions and stop as soon as either one finishes.

        Each direction only ends when its own source socket closes. Waiting
        for both would keep the upstream socket open after the client leaves,
        so the unfinished direction is cancelled once the first one returns.
        Exceptions raised by either direction are logged, not propagated.
        """
        tasks = [
            asyncio.ensure_future(self._proxy_client_to_provider(client_ws, provider_ws, provider)),
            asyncio.ensure_future(self._proxy_provider_to_client(provider_ws, client_ws, provider)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # no-op for a task that has already finished
            # Let cancellations settle; return_exceptions keeps one failing
            # direction from hiding the other's outcome.
            results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                self.logger.error("WebSocket relay task failed: %s", result)

    async def _close_provider_connection(self, provider_ws, provider_session) -> None:
        """Close the upstream socket and session if they were opened; log close errors."""
        if provider_ws is not None and not provider_ws.closed:
            try:
                await provider_ws.close()
            except Exception as e:
                self.logger.error("Error closing provider WebSocket: %s", e)
        if provider_session is not None:
            try:
                await provider_session.close()
            except Exception as e:
                self.logger.error("Error closing provider session: %s", e)

    async def _connect_to_provider(self, provider) -> Optional[tuple]:
        """Open a WebSocket to ``provider.ws_url`` on a fresh ``ClientSession``.

        Returns:
            ``(ws, session)`` on success. The caller owns both and must close
            them. Returns ``None`` if connecting fails; the session is closed
            in that case.

        Raises:
            asyncio.CancelledError: re-raised after closing the session, so a
                cancelled connect does not leak it.
        """
        session = None
        try:
            session = ClientSession()
            self.logger.info("Attempting to connect to provider %s at %s", provider.name, provider.ws_url)
            ws = await session.ws_connect(provider.ws_url)
            self.logger.info("Successfully connected to provider %s WebSocket at %s", provider.name, provider.ws_url)
            return (ws, session)
        except asyncio.CancelledError:
            # Listed first: on Python < 3.8 CancelledError subclasses Exception.
            if session is not None:
                await session.close()
            raise
        except Exception as e:
            self.logger.error("Failed to connect to provider %s at %s: %s", provider.name, provider.ws_url, e)
            if session is not None:
                await session.close()
            return None

    async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Forward client text frames to the provider.

        Each frame is parsed as JSON, passed through
        ``provider.transform_request`` and sent upstream. Invalid JSON is
        logged and skipped, and any other error ends this direction. The loop
        returns when the client closes or reports an error. Other frame types
        (e.g. binary) are ignored.
        """
        async for msg in client_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    # Used only for logging. Non-object JSON (e.g. a JSON-RPC
                    # batch array) has no .get(), and raising here would end
                    # the whole connection.
                    method = data.get('method', 'unknown') if isinstance(data, dict) else 'unknown'

                    self.logger.debug("Received from client: %s", data)

                    transformed_request = provider.transform_request(data)

                    await provider_ws.send_str(json.dumps(transformed_request))
                    self.logger.debug("Forwarded message to %s: %s", provider.name, method)

                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from client")
                except Exception as e:
                    self.logger.error("Error forwarding to provider: %s", e)
                    break

            elif msg.type == WSMsgType.ERROR:
                self.logger.error('WebSocket error: %s', client_ws.exception())
                break

            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                break

    async def _proxy_provider_to_client(self, provider_ws, client_ws: web.WebSocketResponse, provider) -> None:
        """Forward provider text frames to the client.

        Each frame is parsed as JSON and passed through
        ``provider.transform_response``. JSON-object responses are tagged with
        ``_cached=False`` and ``_provider`` before sending. Other JSON values
        (e.g. batch arrays) are forwarded unchanged. Per-message errors are
        logged and skipped; the loop ends when the provider socket closes or
        reports an error.
        """
        async for msg in provider_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    self.logger.debug("Received from provider %s: %s", provider.name, data)

                    transformed_response = provider.transform_response(data)

                    if isinstance(transformed_response, dict):
                        # NOTE(review): this matches only results whose str()
                        # contains "subscription"; confirm it detects the
                        # provider's subscribe responses as intended.
                        if "result" in transformed_response and "subscription" in str(transformed_response.get("result", {})):
                            subscription_id = transformed_response.get("result")
                            if subscription_id:
                                self.subscription_mappings[str(subscription_id)] = provider.name
                                self.logger.debug("SIGNATURE_SUBSCRIBE: Mapped subscription %s to %s", subscription_id, provider.name)

                        transformed_response["_cached"] = False
                        transformed_response["_provider"] = provider.name

                    await client_ws.send_str(json.dumps(transformed_response))
                    self.logger.debug("Forwarded response to client from %s: %s", provider.name, transformed_response)

                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from provider %s", provider.name)
                except Exception as e:
                    # Don't break here - one bad message should not end the session.
                    self.logger.error("Error forwarding from provider: %s", e)

            elif msg.type == WSMsgType.ERROR:
                self.logger.error('Provider WebSocket error: %s', provider_ws.exception())
                break

            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                self.logger.warning("Provider WebSocket connection closed from %s", provider.name)
                break

        self.logger.warning("Provider-to-client message loop ended for %s", provider.name)
|
|
|
|
|
|
async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
    """aiohttp handler that serves one WebSocket connection through a new proxy.

    The ``Router`` is taken from ``request.app['router']``. A fresh
    ``WebSocketProxy`` is built for every connection, so per-proxy state
    such as ``subscription_mappings`` is not shared between connections.
    """
    proxy = WebSocketProxy(request.app['router'])
    response = await proxy.handle_ws_connection(request)
    return response
|
|
|
|
|
|
def setup_ws_routes(app: web.Application) -> None:
    """Register the WebSocket proxy handler on *app* for GET requests to ``/ws``."""
    ws_path = '/ws'
    app.router.add_get(ws_path, handle_ws_connection)
|