python-ecosys/aiohttp ignores 'headers' argument for websocket #940

Open
jomasnash opened this issue Nov 30, 2024 · 5 comments

Comments

@jomasnash

jomasnash commented Nov 30, 2024

Problem:

With the latest version of aiohttp, the 'headers' argument is ignored when connecting to a WebSocket.

As a consequence, you cannot connect to most online MQTT brokers over WebSocket,
because they require the header entry "Sec-WebSocket-Protocol":"mqtt" in the handshake of
the protocol upgrade.

See this small example:
It connects to an MQTT broker and then sends the MQTT CONNECT packet.
It should then get a reply of opcode:2, data: b' \x02\x00\x00', where 'data' is an MQTT CONNACK packet.
Because the header entry "Sec-WebSocket-Protocol":"mqtt" is missing, most brokers will refuse the connection or
refuse to accept MQTT packets.

import aiohttp
import asyncio

async def connect():
    url = "ws://test.mosquitto.org:8080"
    headers = {"Sec-WebSocket-Protocol":"mqtt"}
    connect_pkt = bytearray(b'\x10\x10\x00\x04MQTT\x04\x02\x00x\x00\x04fz54')
    
    async with aiohttp.ClientSession(headers=headers).ws_connect(url) as ws:
        print("Connected")
        await ws.send_bytes(connect_pkt)
        opcode, data = await ws.ws.receive()
        print(f"opcode:{opcode}, data:{data}")

asyncio.run(connect())
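
For readers unfamiliar with the raw bytes above, here is a minimal sketch of how the CONNECT packet is laid out and how the expected CONNACK reply can be decoded, assuming MQTT 3.1.1 with a clean session and no credentials. The helper names `build_connect` and `parse_connack` are hypothetical and are not part of aiohttp or any MQTT library.

```python
import struct


def build_connect(client_id: str, keepalive: int = 120) -> bytes:
    """Build a minimal MQTT 3.1.1 CONNECT packet (clean session, no credentials)."""
    cid = client_id.encode()
    variable_header = (
        struct.pack("!H", 4) + b"MQTT"   # protocol name, length-prefixed
        + bytes([4])                      # protocol level 4 = MQTT 3.1.1
        + bytes([0x02])                   # connect flags: clean session only
        + struct.pack("!H", keepalive)    # keep-alive in seconds
    )
    payload = struct.pack("!H", len(cid)) + cid  # length-prefixed client ID
    remaining = variable_header + payload
    # This sketch only handles the single-byte "remaining length" encoding.
    if len(remaining) >= 128:
        raise ValueError("packet too long for this sketch")
    return bytes([0x10, len(remaining)]) + remaining  # 0x10 = CONNECT


def parse_connack(data: bytes):
    """Return (session_present, return_code) from a 4-byte CONNACK packet."""
    if len(data) != 4 or data[0] != 0x20 or data[1] != 0x02:
        raise ValueError("not a CONNACK packet")
    return bool(data[2] & 0x01), data[3]
```

With the client ID "fz54" and a keep-alive of 120 seconds this reproduces the `connect_pkt` bytes from the example, and a return code of 0 in the CONNACK means the broker accepted the connection.
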
@jomasnash
Author

Update:

Changing the following two lines will solve the problem:

In __init__.py (line 266), pass 'self._base_headers' as an argument:

async def _ws_connect(self, url, ssl=None):
    ws_client = WebSocketClient(self._base_headers)  # <--- add self._base_headers

And in aiohttp_ws.py (line 139) change:

async def handshake(self, uri, ssl, req):
    # headers = {}  # <--- replace this
    headers = self.params  # <--- with this

Now the example code will work.
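
To illustrate what the change achieves, here is a minimal sketch of merging session-level headers into the WebSocket upgrade headers. It is written against the CPython standard library and is not the library's actual code; the function name `build_handshake_headers` and its parameters are assumptions made for illustration. It copies the base headers instead of assigning the shared dict directly, so the handshake fields are not written back into the session's headers.

```python
import base64
import os


def build_handshake_headers(host, port, base_headers=None):
    """Return headers for a WebSocket upgrade request (hypothetical helper)."""
    # Copy so the session's shared header dict is never mutated.
    headers = dict(base_headers or {})
    # Random 16-byte nonce, base64-encoded, as required by RFC 6455.
    key = base64.b64encode(os.urandom(16)).decode()
    headers.update(
        {
            "Host": f"{host}:{port}",
            "Connection": "Upgrade",
            "Upgrade": "websocket",
            "Sec-WebSocket-Key": key,
            "Sec-WebSocket-Version": "13",
        }
    )
    return headers


base = {"Sec-WebSocket-Protocol": "mqtt"}
handshake = build_handshake_headers("test.mosquitto.org", 8080, base)
```

Here `handshake` carries the "Sec-WebSocket-Protocol" entry alongside the upgrade fields, while `base` is left unchanged.
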

@Carglglz
Contributor

Carglglz commented Dec 2, 2024

@jomasnash
Nice find. I see now that I missed this use case while developing the WebSocketClient 😓.
Would you like to make a PR with the fix?

@jomasnash
Author

@Carglglz
I made a PR. (It is the first PR I have ever made, so I hope it is OK.)

@rambo

rambo commented Mar 27, 2025

I wonder whether params is supposed to be able to contain something other than headers.

Another solution is to pass the client's _base_headers to the request_raw method when the handshake is performed (the raw method itself should not magically add the base headers).

@rambo

rambo commented Mar 27, 2025

And if one does not want to modify aiohttp itself (because Reasons:tm:), we can do a fugly method override on request_raw (which, as far as I can tell, does not use self for anything, so it should in fact be a classmethod).

This also demonstrates handling the context manager manually so that multiple tasks can share this websocket to send messages to the backend. (This was rather quickly redacted to protect the guilty and to remove things irrelevant to the discussion; if it does not work as-is, fixing it is left as an exercise for the reader.)

import asyncio
import logging
import ssl
import os

import aiohttp


LOGGER = logging.getLogger(__name__)
CADIR = "cacerts"
BASE_URL = "http://example.com"


def get_ssl_client_context() -> ssl.SSLContext:
    """Load the CA-certs to context and enable verify"""
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    for item in os.ilistdir(CADIR):
        name = item[0]
        type = item[1]
        if type != 0x8000:
            LOGGER.warning("Skipping {}, is not file".format(name))
            continue
        if not name.endswith(".pem"):
            LOGGER.warning("Skipping {}, is not .pem".format(name))
            continue
        cafile = f"{CADIR}/{name}"
        LOGGER.info("Loading cert: {}".format(cafile))
        ctx.load_verify_locations(cafile=str(cafile))
    ctx.verify_mode = ssl.CERT_REQUIRED
    return ctx

class WSClient:
    """Handles enrollment etc"""

    _instance: "WSClient" | None = None

    @classmethod
    def singleton(cls) -> "WSClient":
        """Get a singleton"""
        if cls._instance is None:
            cls._instance = WSClient()
        return cls._instance

    def __init__(self) -> None:
        """basic init"""
        self.client = aiohttp.ClientSession(base_url=BASE_URL)
        self.ssl_ctx: ssl.SSLContext | None = None
        if BASE_URL.startswith("https"):
            self.ssl_ctx = get_ssl_client_context()
        self.ws_ctx: aiohttp._WSRequestContextManager | None = None
        self.ws: aiohttp.WebSocketClient | None = None  # FIXME: type
        self._orig_raw_rq = self.client.request_raw
        self.client.request_raw = self._request_raw_wheaders

    def _request_raw_wheaders(
        self,
        method,
        url,
        data=None,
        json=None,
        ssl=None,
        params=None,
        headers=None,
        is_handshake=False,
        version=None,
    ) -> (
        asyncio.StreamReader,
        asyncio.StreamWriter,
    ) | asyncio.StreamReader:
        """Wrap the request_raw method to force auth header"""
        # Copy so that neither a shared default nor the caller's dict is mutated
        headers = dict(headers or {})
        if "Authorization" in self.client._base_headers:
            headers["Authorization"] = self.client._base_headers["Authorization"]
        LOGGER.debug("calling orig: url={} headers={}, is_handshake={}, version={}".format(url, headers, is_handshake, version))
        return self._orig_raw_rq(
            method, url, data, json, ssl, params, headers, is_handshake, version
        )

    async def connect_ws(self, force: bool = False) -> bool:
        """Connect to ws_ctx"""
        if self.ws_ctx:
            if not force:
                LOGGER.debug("WS already connected")
                return True
            await self.disconnect_ws()
        ws_url = "ws" + BASE_URL[4:] + "/api/v1/something/ws"
        LOGGER.info("Connecting to {}".format(ws_url))
        # We need to deal with the context manager ourselves in this case
        try:
            self.ws_ctx = self.client.ws_connect(ws_url, ssl=self.ssl_ctx)
            self.ws = await self.ws_ctx.__aenter__()
            return True
        except (OSError, AssertionError) as exc:
            LOGGER.error("ws connection exception {}".format(exc))
            self.ws_ctx = None
            self.ws = None
            return False

    async def disconnect_ws(self) -> None:
        """Do a clean close and contextmanager exit"""
        if not self.ws_ctx:
            LOGGER.debug("No ws_ctx, returning early")
            return
        try:
            await self.ws.close()
        except (OSError, AssertionError) as exc:
            LOGGER.error("ws.close exception {}".format(exc))
        finally:
            self.ws = None
        try:
            await self.ws_ctx.__aexit__(None, None, None)
        except (OSError, AssertionError) as exc:
            LOGGER.error("ws_ctx.__aexit__ exception {}".format(exc))
        finally:
            self.ws_ctx = None

    async def ws_send(self, payload: dict) -> bool:
        """Send payload to websocket, returns False on failure"""
        if not await self.connect_ws():
            return False
        try:
            await self.ws.send_json(payload)
            return True
        except (OSError, AssertionError) as exc:
            LOGGER.error("ws.send_json exception {}".format(exc))
            await self.disconnect_ws()
        return False
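
The override technique used above can be shown in isolation. The following is a minimal sketch with a stand-in `FakeSession` class (hypothetical, not the real aiohttp class) whose `request_raw` simply returns the headers it was given; the helper `force_auth_header` and the token value are likewise assumptions made for illustration. The wrapper copies the headers before adding the Authorization entry, so the caller's dict is not modified.

```python
import asyncio


class FakeSession:
    """Stand-in for a client session; NOT the real aiohttp class."""

    def __init__(self, headers):
        self._base_headers = headers

    async def request_raw(self, method, url, headers=None):
        # Echo back the headers that would have been sent.
        return dict(headers or {})


def force_auth_header(session):
    """Replace session.request_raw with a wrapper that injects Authorization."""
    orig = session.request_raw  # keep a reference to the bound original

    async def wrapper(method, url, headers=None, **kwargs):
        merged = dict(headers or {})  # copy; never mutate the caller's dict
        if "Authorization" in session._base_headers:
            merged["Authorization"] = session._base_headers["Authorization"]
        return await orig(method, url, headers=merged, **kwargs)

    # The instance attribute shadows the class method for this object only.
    session.request_raw = wrapper


session = FakeSession({"Authorization": "Bearer token"})
force_auth_header(session)
caller_headers = {"Upgrade": "websocket"}
sent = asyncio.run(session.request_raw("GET", "/ws", headers=caller_headers))
```

Because the override lives on the instance, other sessions created from the same class keep the original behaviour.
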

3 participants