"""WebSocket proxy relaying JSON text frames between a client and a provider."""

import asyncio
import json
import logging
from typing import Any, Dict, Optional

from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType, web

from router import Router


class WebSocketProxy:
    """Relay WebSocket traffic between one client and one upstream provider.

    The provider is obtained from ``router.get_next_available_provider()``.
    Text frames are JSON-decoded, passed through the provider's
    ``transform_request`` / ``transform_response`` hooks and re-encoded
    before being forwarded to the other side.
    """

    def __init__(self, router: Router):
        self.router = router
        self.logger = logging.getLogger(__name__)
        # str(result) of a subscription-looking response -> provider name.
        self.subscription_mappings: Dict[str, str] = {}

    async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
        """Upgrade *request* to a WebSocket and proxy it to a provider.

        The client socket is closed with code 1011 when no provider is
        available or the provider connection cannot be established.
        The provider socket and its HTTP session are always released before
        this coroutine returns.

        Returns:
            The (closed) client ``WebSocketResponse``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.logger.info("New WebSocket connection established")

        provider = self.router.get_next_available_provider()
        if not provider:
            await ws.close(code=1011, message=b'No available providers')
            return ws

        try:
            # The session is owned here so that it is closed on every exit
            # path; previously it was created and never closed.
            async with ClientSession() as session:
                provider_ws = await self._connect_to_provider(provider, session)
                if provider_ws is None:
                    await ws.close(code=1011, message=b'Failed to connect to provider')
                    return ws
                try:
                    await self._relay(ws, provider_ws, provider)
                finally:
                    await provider_ws.close()
        except Exception as e:
            self.logger.error(f"WebSocket proxy error: {e}")
        finally:
            if not ws.closed:
                await ws.close()

        return ws

    async def _relay(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Run both relay directions until either one finishes.

        As soon as one side ends (close, error, or exception) the other
        direction is cancelled, so a disconnected client does not leave a
        task waiting on the provider indefinitely.
        """
        tasks = [
            asyncio.ensure_future(
                self._proxy_client_to_provider(client_ws, provider_ws, provider)
            ),
            asyncio.ensure_future(
                self._proxy_provider_to_client(provider_ws, client_ws, provider)
            ),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Drain both tasks so cancellation completes and no exception is
            # left unretrieved.
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, asyncio.CancelledError):
                    continue
                if isinstance(result, Exception):
                    self.logger.error(f"WebSocket proxy error: {result}")

    async def _connect_to_provider(
        self, provider, session: Optional[ClientSession] = None
    ) -> Optional[ClientWebSocketResponse]:
        """Open a WebSocket to ``provider.ws_url``.

        Args:
            provider: Object exposing ``ws_url`` and ``name``.
            session: Session to connect with. The caller keeps ownership and
                must close it. When omitted, a private session is created
                for backward compatibility; it is closed if the connection
                fails, but on success it stays open for as long as the
                returned socket is in use, so passing a session is preferred.

        Returns:
            The connected provider WebSocket, or ``None`` on failure.
        """
        owns_session = session is None
        try:
            if owns_session:
                session = ClientSession()
            provider_ws = await session.ws_connect(provider.ws_url)
        except Exception as e:
            self.logger.error(f"Failed to connect to provider {provider.name}: {e}")
            if owns_session and session is not None:
                await session.close()
            return None

        self.logger.info(f"Connected to provider {provider.name} WebSocket")
        return provider_ws

    async def _proxy_client_to_provider(
        self, client_ws: web.WebSocketResponse, provider_ws, provider
    ) -> None:
        """Forward client text frames to the provider.

        Invalid JSON is logged and skipped. Any other failure while
        forwarding, an ERROR frame, or a CLOSE/CLOSING frame ends the loop.
        """
        async for msg in client_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    transformed_request = provider.transform_request(data)
                    await provider_ws.send_str(json.dumps(transformed_request))
                    # Only JSON objects carry a "method" key; guard so that a
                    # non-object payload does not abort the relay after it
                    # has already been forwarded.
                    method = data.get('method', 'unknown') if isinstance(data, dict) else 'unknown'
                    self.logger.debug(f"Forwarded message to {provider.name}: {method}")
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from client")
                except Exception as e:
                    self.logger.error(f"Error forwarding to provider: {e}")
                    break
            elif msg.type == WSMsgType.ERROR:
                self.logger.error(f'WebSocket error: {client_ws.exception()}')
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                break

    async def _proxy_provider_to_client(
        self, provider_ws, client_ws: web.WebSocketResponse, provider
    ) -> None:
        """Forward provider text frames to the client.

        JSON-object responses are annotated and possibly cached (see
        ``_annotate_response``); other JSON values are forwarded unchanged.
        Invalid JSON is logged and skipped. Any other failure, an ERROR
        frame, or a CLOSE/CLOSING frame ends the loop.
        """
        async for msg in provider_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    transformed_response = provider.transform_response(data)
                    if isinstance(transformed_response, dict):
                        self._annotate_response(transformed_response, provider)
                    await client_ws.send_str(json.dumps(transformed_response))
                    self.logger.debug(f"Forwarded response from {provider.name}")
                except json.JSONDecodeError:
                    self.logger.warning(f"Received invalid JSON from provider {provider.name}")
                except Exception as e:
                    self.logger.error(f"Error forwarding from provider: {e}")
                    break
            elif msg.type == WSMsgType.ERROR:
                self.logger.error(f'Provider WebSocket error: {provider_ws.exception()}')
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                break

    def _annotate_response(self, response: Dict[str, Any], provider) -> None:
        """Record subscription ids, tag *response* and cache it when keyed.

        Mutates *response* in place by adding ``_cached`` and ``_provider``.
        """
        # A response counts as subscription-related when the text form of its
        # "result" contains "subscription"; the result is then used as the id.
        # NOTE(review): this heuristic is kept as found - confirm it matches
        # the providers' actual subscription replies.
        if "result" in response and "subscription" in str(response.get("result", {})):
            subscription_id = response.get("result")
            if subscription_id:
                self.subscription_mappings[str(subscription_id)] = provider.name

        response["_cached"] = False
        response["_provider"] = provider.name

        method = response.get("method", "")
        params = response.get("params", {})
        if method and params:
            self.router.cache.set(method, params, response)


async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
    """aiohttp handler: proxy one WebSocket connection using ``app['router']``.

    A new ``WebSocketProxy`` is created for every connection.
    """
    router: Router = request.app['router']
    ws_proxy = WebSocketProxy(router)
    return await ws_proxy.handle_ws_connection(request)


def setup_ws_routes(app: web.Application) -> None:
    """Register the WebSocket proxy handler at ``GET /ws`` on *app*."""
    app.router.add_get('/ws', handle_ws_connection)