import argparse
import asyncio
import os
import logging
import aiohttp
from aiohttp import web

# -----------------------------------------------------------------------------
# 4MB is the default inside aiohttp
# -----------------------------------------------------------------------------

MAX_MSG_SIZE = int(os.environ.get("WSLINK_MAX_MSG_SIZE", 4194304))
HEART_BEAT = int(os.environ.get("WSLINK_HEART_BEAT", 30))  # 30 seconds

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# -----------------------------------------------------------------------------
# Helper classes
# -----------------------------------------------------------------------------


class WsClientConnection:
    def __init__(self, propagate_disconnect=True):
        self._url = None
        self._session = None
        self._ws = None
        self._connected = 0
        self._destination = None
        self._ready = asyncio.get_running_loop().create_future()
        self.propagate_disconnect = propagate_disconnect

    def bind(self, value):
        self._destination = value

    @property
    def ready(self):
        return self._ready

    async def connect(self, url):
        logger.debug("client::connect::%s", url)
        self._url = url
        if self._session is None:
            async with aiohttp.ClientSession() as session:
                logger.debug("client::connect::session")
                self._session = session
                try:
                    async with session.ws_connect(self._url) as ws:
                        logger.debug("client::connect::ws")
                        self._ws = ws
                        self._connected += 1
                        self._ready.set_result(True)
                        async for msg in ws:
                            logger.debug("client::connect::ws::msg")
                            if self._connected < 1:
                                logger.debug("client::connect::ws::msg::disconnect")
                                break

                            if self._destination:
                                logger.debug("client::connect::ws::msg::send")
                                await self._destination.send(msg)
                            else:
                                logger.error("ws-client: No destination for message")

                        # Disconnect
                        self.disconnect()

                        # Cleanup connection
                        if not self._ws.closed:
                            await self._ws.close()
                        self._ws = None
                        self._session = None
                finally:
                    self._ready.set_result(False)

        logger.debug("client::connect::exit")

    async def send(self, msg):
        if self._connected > 0 and not self._ws.closed:
            logger.debug("client::send")
            if msg.type == aiohttp.WSMsgType.TEXT:
                await self._ws.send_str(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await self._ws.send_bytes(msg.data)
            elif msg.type == aiohttp.WSMsgType.PING:
                await self._ws.ping(msg.data)
            elif msg.type == aiohttp.WSMsgType.PONG:
                await self._ws.pong(msg.data)
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await self._ws.close()
            else:
                logger.error("Invalid message to forward")
        else:
            logger.error("client::send - NO SEND")
            logger.error("%s - %s", self._connected, self._ws.closed)
            logger.error("-" * 60)

    def disconnect(self):
        logger.debug("client::disconnect %s", self._connected)
        if self._connected > 0:
            self._connected = -1
            if self._destination and self.propagate_disconnect:
                self._destination.disconnect()

    async def close(self):
        if self._ws is not None:
            await self._ws.close()


# -----------------------------------------------------------------------------


class WsServerConnection:
    def __init__(self, propagate_disconnect=True):
        self._ws = None
        self._destination = None
        self._connected = 0
        self.propagate_disconnect = propagate_disconnect

    def bind(self, value):
        self._destination = value

    async def connect(self, request):
        logger.debug("server::connect")
        self._ws = web.WebSocketResponse(
            max_msg_size=MAX_MSG_SIZE, heartbeat=HEART_BEAT
        )
        await self._ws.prepare(request)
        logger.debug("server::connect::prepare")
        self._connected += 1

        if self._connected > 0:
            async for msg in self._ws:
                logger.debug("server::connect::ws::msg")
                if self._connected < 1:
                    break

                if self._destination:
                    logger.debug("server::connect::ws::msg::send-begin")
                    await self._destination.send(msg)
                    logger.debug("server::connect::ws::msg::send-end")
                else:
                    logger.error("ws-server: No destination for message")

        # Disconnect
        self.disconnect()

        # Cleanup connection
        if not self._ws.closed:
            await self._ws.close()
        self._ws = None

        logger.debug("server::connect::exit")

    async def send(self, msg):
        if self._connected > 0 and not self._ws.closed:
            logger.debug("server::send")
            if msg.type == aiohttp.WSMsgType.TEXT:
                await self._ws.send_str(msg.data)
            elif msg.type == aiohttp.WSMsgType.BINARY:
                await self._ws.send_bytes(msg.data)
            elif msg.type == aiohttp.WSMsgType.PING:
                await self._ws.ping(msg.data)
            elif msg.type == aiohttp.WSMsgType.PONG:
                await self._ws.pong(msg.data)
            elif msg.type == aiohttp.WSMsgType.CLOSE:
                await self._ws.close()
            else:
                logger.error("Invalid message to forward")
        else:
            logger.error("server::send - NO SEND")
            logger.error("%s - %s", self._connected, self._ws.closed)
            logger.error("-" * 60)

    def disconnect(self):
        logger.debug("server::disconnect %s", self._connected)
        if self._connected > 0:
            self._connected = -1
            if self._destination and self.propagate_disconnect:
                self._destination.disconnect()

    async def close(self):
        if self._ws is not None:
            await self._ws.close()


# -----------------------------------------------------------------------------


class ForwardConnection:
    def __init__(self, request, url):
        self._req = request
        self._url = url
        self._ws_client = WsClientConnection()
        self._ws_server = WsServerConnection()
        self._ws_server.bind(self._ws_client)
        self._ws_client.bind(self._ws_server)

    async def connect(self):
        task = asyncio.create_task(self._ws_client.connect(self._url))
        task.add_done_callback(lambda *args, **kwargs: self._ws_server.disconnect())
        await self._ws_client.ready
        await self._ws_server.connect(self._req)

    def disconnect(self):
        self._ws_client.disconnect()
        self._ws_server.disconnect()


class SinkConnection:
    def __init__(self, request):
        self._process_req = request
        self._client_req = None
        self._process_ws = None
        self._client_ws = None

    def can_handle(self, request):
        if self._process_req == request:
            return True

        if self._client_req == request:
            return True

        if self._client_ws is None:
            return True

        return False

    async def connect(self, request):
        if self._process_req == request:
            # First connection is the actual server. Cannot reconnect.
            self._process_ws = WsServerConnection()
            await self._process_ws.connect(request)
            if self._client_ws is not None:
                await self._client_ws.close()

            return True
        elif self._client_req is None:
            # Second connection is the browser. Can reconnect.
            self._client_req = request
            self._client_ws = WsServerConnection(propagate_disconnect=False)
            self._client_ws.bind(self._process_ws)
            self._process_ws.bind(self._client_ws)
            await self._client_ws.connect(request)

            self._client_ws.bind(None)
            self._process_ws.bind(None)

            self._client_req = None
            self._client_ws = None

        return False


# -----------------------------------------------------------------------------
# Handlers
# -----------------------------------------------------------------------------


async def _root_handler(request):
    if request.query_string:
        return web.HTTPFound(f"index.html?{request.query_string}")
    return web.HTTPFound("index.html")


# -----------------------------------------------------------------------------


class WsHandler:
    def __init__(self):
        self._forward_map = {}
        self._relay_map = {}

    def get_handler(self, mode="forward"):
        logger.info("get_handler %s", mode)
        if mode == "forward":
            return self.forward_connect
        if mode == "relay":
            return self.relay_connect
        logger.error("No handler !!!")

    # -----------------------------
    # forward infrastructure
    # -----------------------------

    async def forward_connect(self, request):
        host = request.match_info.get("host", "localhost")
        port = int(request.match_info.get("port", "1234"))
        path = request.match_info.get("path", "ws")
        target_url = f"ws://{host}:{port}/{path}"
        logger.info("=> %s", target_url)

        if target_url in self._forward_map:
            raise web.HTTPForbidden()

        forwarder = ForwardConnection(request, target_url)
        self._forward_map[target_url] = forwarder
        await forwarder.connect()
        self._forward_map.pop(target_url)

    # -----------------------------
    # relay server infrastructure
    # -----------------------------

    async def relay_connect(self, request):
        id = request.path

        if id not in self._relay_map:
            handler = SinkConnection(request)
            self._relay_map[id] = handler

        handler = self._relay_map[id]
        if not handler.can_handle(request):
            raise web.HTTPForbidden()

        if await handler.connect(request):
            # Only pop when the server dies
            self._relay_map.pop(id)


# -----------------------------------------------------------------------------
# Executable
# -----------------------------------------------------------------------------


def main(host=None, port=None, www_path=None, proxy_route=None, mode=None):
    wsRelay = WsHandler()

    # Handle CLI
    parser = argparse.ArgumentParser(
        description="Start ws relay with static content delivery",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="the interface for the web-server to listen on (default: 0.0.0.0)",
        dest="host",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=8080,
        help="port number for the web-server to listen on (default: 8080)",
        dest="port",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="forward",
        help="Working mode [forward, relay] (default: forward)",
    )
    parser.add_argument("--www", type=str, help="Directory to serve", dest="www_path")
    parser.add_argument(
        "--proxy-route",
        type=str,
        help="Proxy URL pattern (default: /proxy/{port}) mode::forward(ws://{host=localhost}:{port=1234}/{path=ws})",
        default="/proxy/{port}",
        dest="proxy_route",
    )
    args, _ = parser.parse_known_args()

    if host is None:
        host = args.host

    if port is None:
        port = args.port

    if mode is None:
        mode = args.mode

    if www_path is None:
        www_path = args.www_path

    if proxy_route is None:
        proxy_route = args.proxy_route

    # Manage routes
    routes = []

    # Need to be first: static delivery should be a fallback
    if proxy_route is not None:
        logger.info("Proxy route: %s", proxy_route)
        routes.append(web.get(proxy_route, wsRelay.get_handler(mode)))

    # Serve static content
    if www_path is not None:
        logger.info("WWW: %s", www_path)
        routes.append(web.get("/", _root_handler))
        routes.append(web.static("/", www_path))

    # Setup web app
    logger.info("Starting relay server: %s %s", host, port)
    web_app = web.Application()
    web_app.add_routes(routes)
    web.run_app(web_app, host=host, port=port)


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    main()
