# solana-proxy/ws_proxy.py
# 2025-07-25 16:20:48 +05:30
#
# 187 lines
# 8.3 KiB
# Python

import json
import asyncio
import logging
from typing import Dict, Any, Optional
from aiohttp import web, WSMsgType, ClientSession
from router import Router
class WebSocketProxy:
    """Relay client WebSocket connections to an upstream provider WebSocket.

    For each client connection one provider is taken from
    ``router.get_next_available_provider()``.  Text frames are relayed in both
    directions: client requests pass through ``provider.transform_request``
    and provider messages through ``provider.transform_response`` before
    being forwarded.
    """

    def __init__(self, router: Router):
        self.router = router
        self.logger = logging.getLogger(__name__)
        # Subscription id (stringified) -> name of the provider that issued it.
        # Written when a subscribe acknowledgement is relayed; not read in this class.
        self.subscription_mappings: Dict[str, str] = {}

    async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
        """Accept a client WebSocket and proxy it to one provider until either side ends.

        The client socket is closed with code 1011 when no provider is
        available or the upstream connection cannot be opened.  Once relaying
        starts, the first direction to finish (disconnect or error) stops the
        other one, and the client socket, provider socket and provider
        ``ClientSession`` are all closed.

        Returns:
            The prepared (and, on return, closed) client ``WebSocketResponse``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.logger.info("New WebSocket connection established")

        provider = self.router.get_next_available_provider()
        if not provider:
            await ws.close(code=1011, message=b'No available providers')
            return ws

        try:
            provider_connection = await self._connect_to_provider(provider)
            if not provider_connection:
                await ws.close(code=1011, message=b'Failed to connect to provider')
                return ws
            provider_ws, provider_session = provider_connection
            try:
                await self._run_pumps(ws, provider_ws, provider)
            finally:
                # Close the upstream socket, and always its session, even if
                # closing the socket itself raises.
                try:
                    if not provider_ws.closed:
                        await provider_ws.close()
                finally:
                    await provider_session.close()
        except Exception:
            self.logger.exception("WebSocket proxy error")
        finally:
            if not ws.closed:
                await ws.close()
        return ws

    async def _run_pumps(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Run both relay directions and stop both as soon as either one finishes.

        ``asyncio.gather`` would wait for *both* directions, so a client
        disconnect left the provider->client loop (and the upstream socket)
        running until the provider happened to close.  Exceptions raised by
        either direction are logged rather than propagated.
        """
        tasks = [
            asyncio.create_task(self._proxy_client_to_provider(client_ws, provider_ws, provider)),
            asyncio.create_task(self._proxy_provider_to_client(provider_ws, client_ws, provider)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Reap both tasks so cancellation completes and no exception is
            # left unretrieved on a finished task.
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                    self.logger.error("WebSocket relay task for %s failed: %r", provider.name, result)

    async def _connect_to_provider(self, provider) -> Optional[tuple]:
        """Open a WebSocket connection to ``provider.ws_url``.

        Returns:
            ``(ws, session)`` on success (the caller owns both and must close
            them), or ``None`` on failure, in which case the session created
            here has already been closed.
        """
        session = None
        try:
            session = ClientSession()
            self.logger.info("Attempting to connect to provider %s at %s", provider.name, provider.ws_url)
            ws = await session.ws_connect(provider.ws_url)
        except Exception as e:
            self.logger.error("Failed to connect to provider %s at %s: %s", provider.name, provider.ws_url, e)
            if session is not None:
                await session.close()
            return None
        self.logger.info("Successfully connected to provider %s WebSocket at %s", provider.name, provider.ws_url)
        return (ws, session)

    @staticmethod
    def _subscription_id(message) -> Optional[Any]:
        """Return the subscription id from a subscribe acknowledgement, else ``None``.

        An acknowledgement is taken to be a JSON-RPC response object (has
        ``id``, no ``method``) whose ``result`` is an int or str.  Booleans are
        excluded explicitly: ``bool`` is an ``int`` subclass, and an
        ``*Unsubscribe`` response returns ``true``/``false``.
        """
        if not isinstance(message, dict) or "method" in message or "id" not in message:
            return None
        result = message.get("result")
        if isinstance(result, bool) or not isinstance(result, (int, str)):
            return None
        return result

    async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
        """Forward client text frames to the provider until the client side ends.

        JSON-RPC ``ping`` requests are answered locally with ``"pong"`` and not
        forwarded.  Invalid JSON and non-object JSON (e.g. a batch array) are
        logged and skipped.  A failure while answering, transforming or
        sending ends the loop.
        """
        async for msg in client_ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from client")
                    continue
                if not isinstance(data, dict):
                    # .get() on a list used to raise here and tear down the
                    # whole connection; skip the message instead.
                    self.logger.warning(
                        "Ignoring non-object JSON message from client (%s)", type(data).__name__
                    )
                    continue

                method = data.get('method', 'unknown')
                self.logger.debug("Received from client: %s", data)
                try:
                    # Handle ping messages locally
                    if method == "ping":
                        pong_response = {
                            "jsonrpc": "2.0",
                            "result": "pong",
                            "id": data.get("id"),
                        }
                        await client_ws.send_str(json.dumps(pong_response))
                        self.logger.info("Responded to ping with pong")
                        continue

                    if method == "signatureSubscribe":
                        params = data.get('params')
                        # Guard the shape: dict or empty params must not raise here.
                        signature = params[0] if isinstance(params, list) and params else None
                        self.logger.info(
                            "SIGNATURE_SUBSCRIBE: Forwarding to %s for signature: %s", provider.name, signature
                        )

                    transformed_request = provider.transform_request(data)
                    if method == "signatureSubscribe":
                        self.logger.info(
                            "SIGNATURE_SUBSCRIBE: Sending to provider %s: %s", provider.name, transformed_request
                        )
                    await provider_ws.send_str(json.dumps(transformed_request))
                    self.logger.info("Forwarded message to %s: %s", provider.name, method)
                except Exception as e:
                    self.logger.error("Error forwarding to provider: %s", e)
                    break
            elif msg.type == WSMsgType.PING:
                await client_ws.pong(msg.data)
                self.logger.debug("Responded to WebSocket ping with pong")
            elif msg.type == WSMsgType.ERROR:
                self.logger.error('WebSocket error: %s', client_ws.exception())
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                break

    async def _proxy_provider_to_client(self, provider_ws, client_ws: web.WebSocketResponse, provider) -> None:
        """Forward provider text frames to the client until the provider side ends.

        Each JSON object goes through ``provider.transform_response``, is
        tagged with ``_cached=False`` and ``_provider``, is passed to
        ``router.cache.set`` when it has both ``method`` and ``params``, and is
        sent to the client.  A bad individual message is logged and skipped.
        The loop stops when the provider closes or errors, or when the client
        connection is gone.
        """
        self.logger.info("Starting provider-to-client message loop for %s", provider.name)
        message_count = 0
        async for msg in provider_ws:
            message_count += 1
            self.logger.debug("Provider %s message #%d, type: %s", provider.name, message_count, msg.type)
            if msg.type == WSMsgType.TEXT:
                if client_ws.closed:
                    self.logger.warning("Client WebSocket closed; stopping relay from %s", provider.name)
                    break
                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError:
                    self.logger.warning("Received invalid JSON from provider %s", provider.name)
                    continue
                if not isinstance(data, dict):
                    self.logger.warning(
                        "Ignoring non-object JSON message from provider %s (%s)",
                        provider.name,
                        type(data).__name__,
                    )
                    continue

                self.logger.debug("Received from provider %s: %s", provider.name, data)
                try:
                    # Special logging for signature subscription traffic
                    ack_id = self._subscription_id(data)
                    if ack_id is not None:
                        self.logger.info(
                            "SIGNATURE_SUBSCRIBE: Got subscription ID response from %s: %s", provider.name, ack_id
                        )
                    elif data.get("method") == "signatureNotification":
                        notification = data.get("params")
                        if not isinstance(notification, dict):
                            notification = {}
                        self.logger.info(
                            "SIGNATURE_NOTIFICATION: From %s, subscription %s, result: %s",
                            provider.name,
                            notification.get("subscription"),
                            notification.get("result"),
                        )

                    transformed_response = provider.transform_response(data)

                    subscription_id = self._subscription_id(transformed_response)
                    if subscription_id is not None:
                        self.subscription_mappings[str(subscription_id)] = provider.name
                        self.logger.info(
                            "SIGNATURE_SUBSCRIBE: Mapped subscription %s to %s", subscription_id, provider.name
                        )

                    transformed_response["_cached"] = False
                    transformed_response["_provider"] = provider.name
                    method = transformed_response.get("method", "")
                    params = transformed_response.get("params", {})
                    if method and params:
                        # NOTE(review): notification params include the per-event
                        # result, so this key looks unlikely to be hit again;
                        # confirm the intent of caching notifications.
                        self.router.cache.set(method, params, transformed_response)

                    await client_ws.send_str(json.dumps(transformed_response))
                    self.logger.debug(
                        "Forwarded response to client from %s: %s", provider.name, transformed_response
                    )
                except ConnectionResetError as e:
                    # The client side is gone; further sends would fail the same way.
                    self.logger.warning("Client connection lost while forwarding from %s: %s", provider.name, e)
                    break
                except Exception as e:
                    # One bad message must not stop the relay.
                    self.logger.error("Error forwarding from provider: %s", e)
            elif msg.type == WSMsgType.PING:
                await provider_ws.pong(msg.data)
                self.logger.debug("Responded to provider WebSocket ping from %s", provider.name)
            elif msg.type == WSMsgType.ERROR:
                self.logger.error('Provider WebSocket error: %s', provider_ws.exception())
                break
            elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                self.logger.warning("Provider WebSocket connection closed from %s", provider.name)
                break
        self.logger.warning(
            "Provider-to-client message loop ended for %s after %d messages", provider.name, message_count
        )
async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
    """aiohttp route handler that proxies one client WebSocket connection.

    Builds a fresh ``WebSocketProxy`` around the ``Router`` stored at
    ``request.app['router']`` and delegates the request to its
    ``handle_ws_connection`` method.
    """
    proxy = WebSocketProxy(request.app['router'])
    return await proxy.handle_ws_connection(request)
def setup_ws_routes(app: web.Application) -> None:
    """Register the module-level WebSocket proxy handler for GET on ``/``."""
    url_router = app.router
    url_router.add_get('/', handle_ws_connection)