"""WebSocket proxy that relays JSON messages between clients and upstream providers."""

import json
import asyncio
import logging
from typing import Dict, Any, Optional
from aiohttp import web, WSMsgType, ClientSession
from router import Router


class WebSocketProxy:
    """Relay JSON text frames between one client WebSocket and one provider WebSocket.

    The provider is chosen with ``router.get_next_available_provider()``.
    Requests pass through ``provider.transform_request`` and responses
    through ``provider.transform_response`` on their way across.
    """

    def __init__(self, router: Router):
        self.router = router
        self.logger = logging.getLogger(__name__)
        # Subscription id (stringified) -> name of the provider that returned it.
        self.subscription_mappings: Dict[str, str] = {}

    async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
        """Upgrade ``request`` to a WebSocket and proxy it to a provider.

        The client socket is closed with code 1011 when no provider is
        available or the provider connection cannot be opened. Otherwise
        messages are relayed in both directions until either side stops.
        Then the provider socket, its HTTP session and the client socket
        are all closed.

        Args:
            request: The incoming aiohttp request.

        Returns:
            The client ``WebSocketResponse`` (closed by the time it is returned).
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.logger.info("New WebSocket connection established")

        provider = self.router.get_next_available_provider()
        if not provider:
            await ws.close(code=1011, message=b'No available providers')
            return ws

        provider_ws = None
        provider_session = None
        try:
            provider_connection = await self._connect_to_provider(provider)
            if not provider_connection:
                await ws.close(code=1011, message=b'Failed to connect to provider')
                return ws

            provider_ws, provider_session = provider_connection
            await self._relay_until_either_side_ends(ws, provider_ws, provider)
        except Exception as e:
            self.logger.exception("WebSocket proxy error: %s", e)
        finally:
            # Cleanup lives in ``finally`` so the provider socket and session
            # are released on errors and on handler cancellation too, not
            # just on the happy path.
            if provider_ws is not None and not provider_ws.closed:
                await provider_ws.close()
            if provider_session is not None:
                await provider_session.close()
            if not ws.closed:
                await ws.close()

        return ws

    async def _relay_until_either_side_ends(self, client_ws: web.WebSocketResponse,
                                            provider_ws, provider) -> None:
        """Run both forwarding loops and stop both as soon as one finishes.

        Waiting for *both* loops to finish on their own is not enough. Once
        one peer is gone, the other loop stays blocked reading a socket that
        may never close, so the connection would never be torn down.
        """
        tasks = [
            asyncio.ensure_future(
                self._proxy_client_to_provider(client_ws, provider_ws, provider)),
            asyncio.ensure_future(
                self._proxy_provider_to_client(provider_ws, client_ws, provider)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # no-op for a task that has already finished
            # Let cancelled loops unwind (leave their pending receive) before
            # the caller closes the sockets they were reading from.
            outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        for outcome in outcomes:
            if isinstance(outcome, Exception) and not isinstance(outcome, asyncio.CancelledError):
                self.logger.error("WebSocket proxy loop failed: %s", outcome)

    async def _connect_to_provider(self, provider) -> Optional[tuple]:
        """Open a WebSocket connection to ``provider.ws_url``.

        Returns:
            ``(ws, session)`` on success. The caller owns both and must close
            them. Returns ``None`` if the connection attempt failed; the
            session has already been closed in that case.
        """
        session = ClientSession()
        self.logger.info("Attempting to connect to provider %s at %s",
                         provider.name, provider.ws_url)
        try:
            ws = await session.ws_connect(provider.ws_url)
        except asyncio.CancelledError:
            # On Python 3.8+ cancellation is not an ``Exception``, so the
            # handler below would miss it and leak the session.
            await session.close()
            raise
        except Exception as e:
            self.logger.error("Failed to connect to provider %s at %s: %s",
                              provider.name, provider.ws_url, e)
            await session.close()
            return None

        self.logger.info("Successfully connected to provider %s WebSocket at %s",
                         provider.name, provider.ws_url)
        return (ws, session)

    async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse,
                                        provider_ws, provider) -> None:
        """Forward client text frames to the provider.

        Each frame is parsed as JSON and passed through
        ``provider.transform_request``. Frames that are not valid JSON, or
        that fail to transform, are logged and skipped. A failed send or a
        client socket error ends the loop.
        """
        async for msg in client_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from client")
                    continue

                # Non-object JSON (e.g. a JSON-RPC batch list) has no ``.get``.
                method = data.get('method', 'unknown') if isinstance(data, dict) else 'unknown'
                self.logger.debug("Received from client: %s", data)

                try:
                    payload = json.dumps(provider.transform_request(data))
                except Exception as e:
                    self.logger.error("Error transforming request for %s: %s", provider.name, e)
                    continue

                try:
                    await provider_ws.send_str(payload)
                except Exception as e:
                    self.logger.error("Error forwarding to provider: %s", e)
                    break
                self.logger.debug("Forwarded message to %s: %s", provider.name, method)
            elif msg.type == WSMsgType.ERROR:
                self.logger.error('WebSocket error: %s', client_ws.exception())
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                break

    def _track_subscription(self, response: Dict[str, Any], provider) -> None:
        """Map a subscription id to ``provider.name`` if ``response`` looks like one.

        The heuristic is unchanged from earlier revisions: the stringified
        ``result`` must contain the text "subscription".
        NOTE(review): a bare numeric subscription id (e.g. ``"result": 42``)
        does not match this check; confirm against the providers' formats.
        """
        if "result" in response and "subscription" in str(response.get("result", {})):
            subscription_id = response.get("result")
            if subscription_id:
                self.subscription_mappings[str(subscription_id)] = provider.name
                self.logger.debug("SIGNATURE_SUBSCRIBE: Mapped subscription %s to %s",
                                  subscription_id, provider.name)

    async def _proxy_provider_to_client(self, provider_ws, client_ws: web.WebSocketResponse,
                                        provider) -> None:
        """Forward provider text frames to the client.

        Each frame is parsed as JSON and passed through
        ``provider.transform_response``. Object (dict) responses are tagged
        with ``_cached`` / ``_provider`` and checked for subscription ids.
        Other JSON values, such as batch lists, are forwarded unchanged.
        Per-message errors are logged and skipped. A provider socket error
        or close ends the loop.
        """
        async for msg in provider_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    self.logger.debug("Received from provider %s: %s", provider.name, data)
                    transformed_response = provider.transform_response(data)

                    # Only dicts can carry the annotation keys; indexing a list
                    # would raise and silently drop the whole message.
                    if isinstance(transformed_response, dict):
                        self._track_subscription(transformed_response, provider)
                        transformed_response["_cached"] = False
                        transformed_response["_provider"] = provider.name

                    await client_ws.send_str(json.dumps(transformed_response))
                    self.logger.debug("Forwarded response to client from %s: %s",
                                      provider.name, transformed_response)
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from provider %s", provider.name)
                except Exception as e:
                    # Don't break here - continue processing other messages
                    self.logger.error("Error forwarding from provider: %s", e)
                    continue
            elif msg.type == WSMsgType.ERROR:
                self.logger.error('Provider WebSocket error: %s', provider_ws.exception())
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                self.logger.warning("Provider WebSocket connection closed from %s", provider.name)
                break

        self.logger.warning("Provider-to-client message loop ended for %s", provider.name)


async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
    """aiohttp handler that proxies ``request`` through a new ``WebSocketProxy``.

    Reads the ``Router`` from ``request.app['router']``. A fresh proxy is
    created per connection, so ``subscription_mappings`` is per-connection.
    """
    router: Router = request.app['router']
    ws_proxy = WebSocketProxy(router)
    return await ws_proxy.handle_ws_connection(request)


def setup_ws_routes(app: web.Application) -> None:
    """Register the WebSocket proxy handler at ``GET /ws``."""
    app.router.add_get('/ws', handle_ws_connection)