Add fastapi code
This commit is contained in:
114
venv/lib/python3.11/site-packages/websockets/__init__.py
Normal file
114
venv/lib/python3.11/site-packages/websockets/__init__.py
Normal file
@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .imports import lazy_import
|
||||
from .version import version as __version__ # noqa
|
||||
|
||||
|
||||
__all__ = [ # noqa
|
||||
"AbortHandshake",
|
||||
"basic_auth_protocol_factory",
|
||||
"BasicAuthWebSocketServerProtocol",
|
||||
"broadcast",
|
||||
"ClientConnection",
|
||||
"connect",
|
||||
"ConnectionClosed",
|
||||
"ConnectionClosedError",
|
||||
"ConnectionClosedOK",
|
||||
"Data",
|
||||
"DuplicateParameter",
|
||||
"ExtensionName",
|
||||
"ExtensionParameter",
|
||||
"InvalidHandshake",
|
||||
"InvalidHeader",
|
||||
"InvalidHeaderFormat",
|
||||
"InvalidHeaderValue",
|
||||
"InvalidMessage",
|
||||
"InvalidOrigin",
|
||||
"InvalidParameterName",
|
||||
"InvalidParameterValue",
|
||||
"InvalidState",
|
||||
"InvalidStatus",
|
||||
"InvalidStatusCode",
|
||||
"InvalidUpgrade",
|
||||
"InvalidURI",
|
||||
"LoggerLike",
|
||||
"NegotiationError",
|
||||
"Origin",
|
||||
"parse_uri",
|
||||
"PayloadTooBig",
|
||||
"ProtocolError",
|
||||
"RedirectHandshake",
|
||||
"SecurityError",
|
||||
"serve",
|
||||
"ServerConnection",
|
||||
"Subprotocol",
|
||||
"unix_connect",
|
||||
"unix_serve",
|
||||
"WebSocketClientProtocol",
|
||||
"WebSocketCommonProtocol",
|
||||
"WebSocketException",
|
||||
"WebSocketProtocolError",
|
||||
"WebSocketServer",
|
||||
"WebSocketServerProtocol",
|
||||
"WebSocketURI",
|
||||
]
|
||||
|
||||
lazy_import(
|
||||
globals(),
|
||||
aliases={
|
||||
"auth": ".legacy",
|
||||
"basic_auth_protocol_factory": ".legacy.auth",
|
||||
"BasicAuthWebSocketServerProtocol": ".legacy.auth",
|
||||
"broadcast": ".legacy.protocol",
|
||||
"ClientConnection": ".client",
|
||||
"connect": ".legacy.client",
|
||||
"unix_connect": ".legacy.client",
|
||||
"WebSocketClientProtocol": ".legacy.client",
|
||||
"Headers": ".datastructures",
|
||||
"MultipleValuesError": ".datastructures",
|
||||
"WebSocketException": ".exceptions",
|
||||
"ConnectionClosed": ".exceptions",
|
||||
"ConnectionClosedError": ".exceptions",
|
||||
"ConnectionClosedOK": ".exceptions",
|
||||
"InvalidHandshake": ".exceptions",
|
||||
"SecurityError": ".exceptions",
|
||||
"InvalidMessage": ".exceptions",
|
||||
"InvalidHeader": ".exceptions",
|
||||
"InvalidHeaderFormat": ".exceptions",
|
||||
"InvalidHeaderValue": ".exceptions",
|
||||
"InvalidOrigin": ".exceptions",
|
||||
"InvalidUpgrade": ".exceptions",
|
||||
"InvalidStatus": ".exceptions",
|
||||
"InvalidStatusCode": ".exceptions",
|
||||
"NegotiationError": ".exceptions",
|
||||
"DuplicateParameter": ".exceptions",
|
||||
"InvalidParameterName": ".exceptions",
|
||||
"InvalidParameterValue": ".exceptions",
|
||||
"AbortHandshake": ".exceptions",
|
||||
"RedirectHandshake": ".exceptions",
|
||||
"InvalidState": ".exceptions",
|
||||
"InvalidURI": ".exceptions",
|
||||
"PayloadTooBig": ".exceptions",
|
||||
"ProtocolError": ".exceptions",
|
||||
"WebSocketProtocolError": ".exceptions",
|
||||
"protocol": ".legacy",
|
||||
"WebSocketCommonProtocol": ".legacy.protocol",
|
||||
"ServerConnection": ".server",
|
||||
"serve": ".legacy.server",
|
||||
"unix_serve": ".legacy.server",
|
||||
"WebSocketServerProtocol": ".legacy.server",
|
||||
"WebSocketServer": ".legacy.server",
|
||||
"Data": ".typing",
|
||||
"LoggerLike": ".typing",
|
||||
"Origin": ".typing",
|
||||
"ExtensionHeader": ".typing",
|
||||
"ExtensionParameter": ".typing",
|
||||
"Subprotocol": ".typing",
|
||||
},
|
||||
deprecated_aliases={
|
||||
"framing": ".legacy",
|
||||
"handshake": ".legacy",
|
||||
"parse_uri": ".uri",
|
||||
"WebSocketURI": ".uri",
|
||||
},
|
||||
)
|
||||
230
venv/lib/python3.11/site-packages/websockets/__main__.py
Normal file
230
venv/lib/python3.11/site-packages/websockets/__main__.py
Normal file
@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Set
|
||||
|
||||
from .exceptions import ConnectionClosed
|
||||
from .frames import Close
|
||||
from .legacy.client import connect
|
||||
from .version import version as websockets_version
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
|
||||
def win_enable_vt100() -> None:
|
||||
"""
|
||||
Enable VT-100 for console output on Windows.
|
||||
|
||||
See also https://bugs.python.org/issue29059.
|
||||
|
||||
"""
|
||||
import ctypes
|
||||
|
||||
STD_OUTPUT_HANDLE = ctypes.c_uint(-11)
|
||||
INVALID_HANDLE_VALUE = ctypes.c_uint(-1)
|
||||
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x004
|
||||
|
||||
handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE)
|
||||
if handle == INVALID_HANDLE_VALUE:
|
||||
raise RuntimeError("unable to obtain stdout handle")
|
||||
|
||||
cur_mode = ctypes.c_uint()
|
||||
if ctypes.windll.kernel32.GetConsoleMode(handle, ctypes.byref(cur_mode)) == 0:
|
||||
raise RuntimeError("unable to query current console mode")
|
||||
|
||||
# ctypes ints lack support for the required bit-OR operation.
|
||||
# Temporarily convert to Py int, do the OR and convert back.
|
||||
py_int_mode = int.from_bytes(cur_mode, sys.byteorder)
|
||||
new_mode = ctypes.c_uint(py_int_mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)
|
||||
|
||||
if ctypes.windll.kernel32.SetConsoleMode(handle, new_mode) == 0:
|
||||
raise RuntimeError("unable to set console mode")
|
||||
|
||||
|
||||
def exit_from_event_loop_thread(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
stop: asyncio.Future[None],
|
||||
) -> None:
|
||||
loop.stop()
|
||||
if not stop.done():
|
||||
# When exiting the thread that runs the event loop, raise
|
||||
# KeyboardInterrupt in the main thread to exit the program.
|
||||
if sys.platform == "win32":
|
||||
ctrl_c = signal.CTRL_C_EVENT
|
||||
else:
|
||||
ctrl_c = signal.SIGINT
|
||||
os.kill(os.getpid(), ctrl_c)
|
||||
|
||||
|
||||
def print_during_input(string: str) -> None:
|
||||
sys.stdout.write(
|
||||
# Save cursor position
|
||||
"\N{ESC}7"
|
||||
# Add a new line
|
||||
"\N{LINE FEED}"
|
||||
# Move cursor up
|
||||
"\N{ESC}[A"
|
||||
# Insert blank line, scroll last line down
|
||||
"\N{ESC}[L"
|
||||
# Print string in the inserted blank line
|
||||
f"{string}\N{LINE FEED}"
|
||||
# Restore cursor position
|
||||
"\N{ESC}8"
|
||||
# Move cursor down
|
||||
"\N{ESC}[B"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def print_over_input(string: str) -> None:
|
||||
sys.stdout.write(
|
||||
# Move cursor to beginning of line
|
||||
"\N{CARRIAGE RETURN}"
|
||||
# Delete current line
|
||||
"\N{ESC}[K"
|
||||
# Print string
|
||||
f"{string}\N{LINE FEED}"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
async def run_client(
|
||||
uri: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
inputs: asyncio.Queue[str],
|
||||
stop: asyncio.Future[None],
|
||||
) -> None:
|
||||
try:
|
||||
websocket = await connect(uri)
|
||||
except Exception as exc:
|
||||
print_over_input(f"Failed to connect to {uri}: {exc}.")
|
||||
exit_from_event_loop_thread(loop, stop)
|
||||
return
|
||||
else:
|
||||
print_during_input(f"Connected to {uri}.")
|
||||
|
||||
try:
|
||||
while True:
|
||||
incoming: asyncio.Future[Any] = asyncio.create_task(websocket.recv())
|
||||
outgoing: asyncio.Future[Any] = asyncio.create_task(inputs.get())
|
||||
done: Set[asyncio.Future[Any]]
|
||||
pending: Set[asyncio.Future[Any]]
|
||||
done, pending = await asyncio.wait(
|
||||
[incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Cancel pending tasks to avoid leaking them.
|
||||
if incoming in pending:
|
||||
incoming.cancel()
|
||||
if outgoing in pending:
|
||||
outgoing.cancel()
|
||||
|
||||
if incoming in done:
|
||||
try:
|
||||
message = incoming.result()
|
||||
except ConnectionClosed:
|
||||
break
|
||||
else:
|
||||
if isinstance(message, str):
|
||||
print_during_input("< " + message)
|
||||
else:
|
||||
print_during_input("< (binary) " + message.hex())
|
||||
|
||||
if outgoing in done:
|
||||
message = outgoing.result()
|
||||
await websocket.send(message)
|
||||
|
||||
if stop in done:
|
||||
break
|
||||
|
||||
finally:
|
||||
await websocket.close()
|
||||
assert websocket.close_code is not None and websocket.close_reason is not None
|
||||
close_status = Close(websocket.close_code, websocket.close_reason)
|
||||
|
||||
print_over_input(f"Connection closed: {close_status}.")
|
||||
|
||||
exit_from_event_loop_thread(loop, stop)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Parse command line arguments.
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="python -m websockets",
|
||||
description="Interactive WebSocket client.",
|
||||
add_help=False,
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--version", action="store_true")
|
||||
group.add_argument("uri", metavar="<uri>", nargs="?")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.version:
|
||||
print(f"websockets {websockets_version}")
|
||||
return
|
||||
|
||||
if args.uri is None:
|
||||
parser.error("the following arguments are required: <uri>")
|
||||
|
||||
# If we're on Windows, enable VT100 terminal support.
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
win_enable_vt100()
|
||||
except RuntimeError as exc:
|
||||
sys.stderr.write(
|
||||
f"Unable to set terminal to VT100 mode. This is only "
|
||||
f"supported since Win10 anniversary update. Expect "
|
||||
f"weird symbols on the terminal.\nError: {exc}\n"
|
||||
)
|
||||
sys.stderr.flush()
|
||||
|
||||
try:
|
||||
import readline # noqa
|
||||
except ImportError: # Windows has no `readline` normally
|
||||
pass
|
||||
|
||||
# Create an event loop that will run in a background thread.
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# Due to zealous removal of the loop parameter in the Queue constructor,
|
||||
# we need a factory coroutine to run in the freshly created event loop.
|
||||
async def queue_factory() -> asyncio.Queue[str]:
|
||||
return asyncio.Queue()
|
||||
|
||||
# Create a queue of user inputs. There's no need to limit its size.
|
||||
inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory())
|
||||
|
||||
# Create a stop condition when receiving SIGINT or SIGTERM.
|
||||
stop: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
# Schedule the task that will manage the connection.
|
||||
loop.create_task(run_client(args.uri, loop, inputs, stop))
|
||||
|
||||
# Start the event loop in a background thread.
|
||||
thread = threading.Thread(target=loop.run_forever)
|
||||
thread.start()
|
||||
|
||||
# Read from stdin in the main thread in order to receive signals.
|
||||
try:
|
||||
while True:
|
||||
# Since there's no size limit, put_nowait is identical to put.
|
||||
message = input("> ")
|
||||
loop.call_soon_threadsafe(inputs.put_nowait, message)
|
||||
except (KeyboardInterrupt, EOFError): # ^C, ^D
|
||||
loop.call_soon_threadsafe(stop.set_result, None)
|
||||
|
||||
# Wait for the event loop to terminate.
|
||||
thread.join()
|
||||
|
||||
# For reasons unclear, even though the loop is closed in the thread,
|
||||
# it still thinks it's running here.
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4
venv/lib/python3.11/site-packages/websockets/auth.py
Normal file
4
venv/lib/python3.11/site-packages/websockets/auth.py
Normal file
@ -0,0 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
|
||||
from .legacy.auth import * # noqa
|
||||
344
venv/lib/python3.11/site-packages/websockets/client.py
Normal file
344
venv/lib/python3.11/site-packages/websockets/client.py
Normal file
@ -0,0 +1,344 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, List, Optional, Sequence
|
||||
|
||||
from .connection import CLIENT, CONNECTING, OPEN, Connection, State
|
||||
from .datastructures import Headers, MultipleValuesError
|
||||
from .exceptions import (
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidHeaderValue,
|
||||
InvalidStatus,
|
||||
InvalidUpgrade,
|
||||
NegotiationError,
|
||||
)
|
||||
from .extensions import ClientExtensionFactory, Extension
|
||||
from .headers import (
|
||||
build_authorization_basic,
|
||||
build_extension,
|
||||
build_host,
|
||||
build_subprotocol,
|
||||
parse_connection,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
parse_upgrade,
|
||||
)
|
||||
from .http11 import Request, Response
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
LoggerLike,
|
||||
Origin,
|
||||
Subprotocol,
|
||||
UpgradeProtocol,
|
||||
)
|
||||
from .uri import WebSocketURI
|
||||
from .utils import accept_key, generate_key
|
||||
|
||||
|
||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
|
||||
from .legacy.client import * # isort:skip # noqa
|
||||
|
||||
|
||||
__all__ = ["ClientConnection"]
|
||||
|
||||
|
||||
class ClientConnection(Connection):
|
||||
"""
|
||||
Sans-I/O implementation of a WebSocket client connection.
|
||||
|
||||
Args:
|
||||
wsuri: URI of the WebSocket server, parsed
|
||||
with :func:`~websockets.uri.parse_uri`.
|
||||
origin: value of the ``Origin`` header. This is useful when connecting
|
||||
to a server that validates the ``Origin`` header to defend against
|
||||
Cross-Site WebSocket Hijacking attacks.
|
||||
extensions: list of supported extensions, in order in which they
|
||||
should be tried.
|
||||
subprotocols: list of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
state: initial state of the WebSocket connection.
|
||||
max_size: maximum size of incoming messages in bytes;
|
||||
:obj:`None` to disable the limit.
|
||||
logger: logger for this connection;
|
||||
defaults to ``logging.getLogger("websockets.client")``;
|
||||
see the :doc:`logging guide <../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wsuri: WebSocketURI,
|
||||
origin: Optional[Origin] = None,
|
||||
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
|
||||
subprotocols: Optional[Sequence[Subprotocol]] = None,
|
||||
state: State = CONNECTING,
|
||||
max_size: Optional[int] = 2**20,
|
||||
logger: Optional[LoggerLike] = None,
|
||||
):
|
||||
super().__init__(
|
||||
side=CLIENT,
|
||||
state=state,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
self.wsuri = wsuri
|
||||
self.origin = origin
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
self.key = generate_key()
|
||||
|
||||
def connect(self) -> Request: # noqa: F811
|
||||
"""
|
||||
Create a handshake request to open a connection.
|
||||
|
||||
You must send the handshake request with :meth:`send_request`.
|
||||
|
||||
You can modify it before sending it, for example to add HTTP headers.
|
||||
|
||||
Returns:
|
||||
Request: WebSocket handshake request event to send to the server.
|
||||
|
||||
"""
|
||||
headers = Headers()
|
||||
|
||||
headers["Host"] = build_host(
|
||||
self.wsuri.host, self.wsuri.port, self.wsuri.secure
|
||||
)
|
||||
|
||||
if self.wsuri.user_info:
|
||||
headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info)
|
||||
|
||||
if self.origin is not None:
|
||||
headers["Origin"] = self.origin
|
||||
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Key"] = self.key
|
||||
headers["Sec-WebSocket-Version"] = "13"
|
||||
|
||||
if self.available_extensions is not None:
|
||||
extensions_header = build_extension(
|
||||
[
|
||||
(extension_factory.name, extension_factory.get_request_params())
|
||||
for extension_factory in self.available_extensions
|
||||
]
|
||||
)
|
||||
headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if self.available_subprotocols is not None:
|
||||
protocol_header = build_subprotocol(self.available_subprotocols)
|
||||
headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
return Request(self.wsuri.resource_name, headers)
|
||||
|
||||
def process_response(self, response: Response) -> None:
|
||||
"""
|
||||
Check a handshake response.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake response received from the server.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: if the handshake response is invalid.
|
||||
|
||||
"""
|
||||
|
||||
if response.status_code != 101:
|
||||
raise InvalidStatus(response)
|
||||
|
||||
headers = response.headers
|
||||
|
||||
connection: List[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade(
|
||||
"Connection", ", ".join(connection) if connection else None
|
||||
)
|
||||
|
||||
upgrade: List[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. It's supposed to be 'WebSocket'.
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
|
||||
|
||||
try:
|
||||
s_w_accept = headers["Sec-WebSocket-Accept"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Accept",
|
||||
"more than one Sec-WebSocket-Accept header found",
|
||||
) from exc
|
||||
|
||||
if s_w_accept != accept_key(self.key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
||||
|
||||
self.extensions = self.process_extensions(headers)
|
||||
|
||||
self.subprotocol = self.process_subprotocol(headers)
|
||||
|
||||
def process_extensions(self, headers: Headers) -> List[Extension]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP response header.
|
||||
|
||||
Check that each extension is supported, as well as its parameters.
|
||||
|
||||
:rfc:`6455` leaves the rules up to the specification of each
|
||||
extension.
|
||||
|
||||
To provide this level of flexibility, for each extension accepted by
|
||||
the server, we check for a match with each extension available in the
|
||||
client configuration. If no match is found, an exception is raised.
|
||||
|
||||
If several variants of the same extension are accepted by the server,
|
||||
it may be configured several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake response headers.
|
||||
|
||||
Returns:
|
||||
List[Extension]: List of accepted extensions.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: to abort the handshake.
|
||||
|
||||
"""
|
||||
accepted_extensions: List[Extension] = []
|
||||
|
||||
extensions = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if extensions:
|
||||
|
||||
if self.available_extensions is None:
|
||||
raise InvalidHandshake("no extensions supported")
|
||||
|
||||
parsed_extensions: List[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in extensions], []
|
||||
)
|
||||
|
||||
for name, response_params in parsed_extensions:
|
||||
|
||||
for extension_factory in self.available_extensions:
|
||||
|
||||
# Skip non-matching extensions based on their name.
|
||||
if extension_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
extension = extension_factory.process_response_params(
|
||||
response_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the server sent. Fail the connection.
|
||||
else:
|
||||
raise NegotiationError(
|
||||
f"Unsupported extension: "
|
||||
f"name = {name}, params = {response_params}"
|
||||
)
|
||||
|
||||
return accepted_extensions
|
||||
|
||||
def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP response header.
|
||||
|
||||
If provided, check that it contains exactly one supported subprotocol.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake response headers.
|
||||
|
||||
Returns:
|
||||
Optional[Subprotocol]: Subprotocol, if one was selected.
|
||||
|
||||
"""
|
||||
subprotocol: Optional[Subprotocol] = None
|
||||
|
||||
subprotocols = headers.get_all("Sec-WebSocket-Protocol")
|
||||
|
||||
if subprotocols:
|
||||
|
||||
if self.available_subprotocols is None:
|
||||
raise InvalidHandshake("no subprotocols supported")
|
||||
|
||||
parsed_subprotocols: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in subprotocols], []
|
||||
)
|
||||
|
||||
if len(parsed_subprotocols) > 1:
|
||||
subprotocols_display = ", ".join(parsed_subprotocols)
|
||||
raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}")
|
||||
|
||||
subprotocol = parsed_subprotocols[0]
|
||||
|
||||
if subprotocol not in self.available_subprotocols:
|
||||
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
|
||||
|
||||
return subprotocol
|
||||
|
||||
def send_request(self, request: Request) -> None:
|
||||
"""
|
||||
Send a handshake request to the server.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake request event.
|
||||
|
||||
"""
|
||||
if self.debug:
|
||||
self.logger.debug("> GET %s HTTP/1.1", request.path)
|
||||
for key, value in request.headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
|
||||
self.writes.append(request.serialize())
|
||||
|
||||
def parse(self) -> Generator[None, None, None]:
|
||||
if self.state is CONNECTING:
|
||||
response = yield from Response.parse(
|
||||
self.reader.read_line,
|
||||
self.reader.read_exact,
|
||||
self.reader.read_to_eof,
|
||||
)
|
||||
|
||||
if self.debug:
|
||||
code, phrase = response.status_code, response.reason_phrase
|
||||
self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
|
||||
for key, value in response.headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
if response.body is not None:
|
||||
self.logger.debug("< [body] (%d bytes)", len(response.body))
|
||||
|
||||
try:
|
||||
self.process_response(response)
|
||||
except InvalidHandshake as exc:
|
||||
response._exception = exc
|
||||
self.handshake_exc = exc
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
else:
|
||||
assert self.state is CONNECTING
|
||||
self.state = OPEN
|
||||
finally:
|
||||
self.events.append(response)
|
||||
|
||||
yield from super().parse()
|
||||
702
venv/lib/python3.11/site-packages/websockets/connection.py
Normal file
702
venv/lib/python3.11/site-packages/websockets/connection.py
Normal file
@ -0,0 +1,702 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Generator, List, Optional, Type, Union
|
||||
|
||||
from .exceptions import (
|
||||
ConnectionClosed,
|
||||
ConnectionClosedError,
|
||||
ConnectionClosedOK,
|
||||
InvalidState,
|
||||
PayloadTooBig,
|
||||
ProtocolError,
|
||||
)
|
||||
from .extensions import Extension
|
||||
from .frames import (
|
||||
OK_CLOSE_CODES,
|
||||
OP_BINARY,
|
||||
OP_CLOSE,
|
||||
OP_CONT,
|
||||
OP_PING,
|
||||
OP_PONG,
|
||||
OP_TEXT,
|
||||
Close,
|
||||
Frame,
|
||||
)
|
||||
from .http11 import Request, Response
|
||||
from .streams import StreamReader
|
||||
from .typing import LoggerLike, Origin, Subprotocol
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Connection",
|
||||
"Side",
|
||||
"State",
|
||||
"SEND_EOF",
|
||||
]
|
||||
|
||||
Event = Union[Request, Response, Frame]
|
||||
"""Events that :meth:`~Connection.events_received` may return."""
|
||||
|
||||
|
||||
class Side(enum.IntEnum):
|
||||
"""A WebSocket connection is either a server or a client."""
|
||||
|
||||
SERVER, CLIENT = range(2)
|
||||
|
||||
|
||||
SERVER = Side.SERVER
|
||||
CLIENT = Side.CLIENT
|
||||
|
||||
|
||||
class State(enum.IntEnum):
|
||||
"""A WebSocket connection is in one of these four states."""
|
||||
|
||||
CONNECTING, OPEN, CLOSING, CLOSED = range(4)
|
||||
|
||||
|
||||
CONNECTING = State.CONNECTING
|
||||
OPEN = State.OPEN
|
||||
CLOSING = State.CLOSING
|
||||
CLOSED = State.CLOSED
|
||||
|
||||
|
||||
SEND_EOF = b""
|
||||
"""Sentinel signaling that the TCP connection must be half-closed."""
|
||||
|
||||
|
||||
class Connection:
|
||||
"""
|
||||
Sans-I/O implementation of a WebSocket connection.
|
||||
|
||||
Args:
|
||||
side: :attr:`~Side.CLIENT` or :attr:`~Side.SERVER`.
|
||||
state: initial state of the WebSocket connection.
|
||||
max_size: maximum size of incoming messages in bytes;
|
||||
:obj:`None` to disable the limit.
|
||||
logger: logger for this connection; depending on ``side``,
|
||||
defaults to ``logging.getLogger("websockets.client")``
|
||||
or ``logging.getLogger("websockets.server")``;
|
||||
see the :doc:`logging guide <../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
side: Side,
|
||||
state: State = OPEN,
|
||||
max_size: Optional[int] = 2**20,
|
||||
logger: Optional[LoggerLike] = None,
|
||||
) -> None:
|
||||
# Unique identifier. For logs.
|
||||
self.id: uuid.UUID = uuid.uuid4()
|
||||
"""Unique identifier of the connection. Useful in logs."""
|
||||
|
||||
# Logger or LoggerAdapter for this connection.
|
||||
if logger is None:
|
||||
logger = logging.getLogger(f"websockets.{side.name.lower()}")
|
||||
self.logger: LoggerLike = logger
|
||||
"""Logger for this connection."""
|
||||
|
||||
# Track if DEBUG is enabled. Shortcut logging calls if it isn't.
|
||||
self.debug = logger.isEnabledFor(logging.DEBUG)
|
||||
|
||||
# Connection side. CLIENT or SERVER.
|
||||
self.side = side
|
||||
|
||||
# Connection state. Initially OPEN because subclasses handle CONNECTING.
|
||||
self.state = state
|
||||
|
||||
# Maximum size of incoming messages in bytes.
|
||||
self.max_size = max_size
|
||||
|
||||
# Current size of incoming message in bytes. Only set while reading a
|
||||
# fragmented message i.e. a data frames with the FIN bit not set.
|
||||
self.cur_size: Optional[int] = None
|
||||
|
||||
# True while sending a fragmented message i.e. a data frames with the
|
||||
# FIN bit not set.
|
||||
self.expect_continuation_frame = False
|
||||
|
||||
# WebSocket protocol parameters.
|
||||
self.origin: Optional[Origin] = None
|
||||
self.extensions: List[Extension] = []
|
||||
self.subprotocol: Optional[Subprotocol] = None
|
||||
|
||||
# Close code and reason, set when a close frame is sent or received.
|
||||
self.close_rcvd: Optional[Close] = None
|
||||
self.close_sent: Optional[Close] = None
|
||||
self.close_rcvd_then_sent: Optional[bool] = None
|
||||
|
||||
# Track if an exception happened during the handshake.
|
||||
self.handshake_exc: Optional[Exception] = None
|
||||
"""
|
||||
Exception to raise if the opening handshake failed.
|
||||
|
||||
:obj:`None` if the opening handshake succeeded.
|
||||
|
||||
"""
|
||||
|
||||
# Track if send_eof() was called.
|
||||
self.eof_sent = False
|
||||
|
||||
# Parser state.
|
||||
self.reader = StreamReader()
|
||||
self.events: List[Event] = []
|
||||
self.writes: List[bytes] = []
|
||||
self.parser = self.parse()
|
||||
next(self.parser) # start coroutine
|
||||
self.parser_exc: Optional[Exception] = None
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
"""
|
||||
WebSocket connection state.
|
||||
|
||||
Defined in 4.1, 4.2, 7.1.3, and 7.1.4 of :rfc:`6455`.
|
||||
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, state: State) -> None:
|
||||
if self.debug:
|
||||
self.logger.debug("= connection is %s", state.name)
|
||||
self._state = state
|
||||
|
||||
@property
|
||||
def close_code(self) -> Optional[int]:
|
||||
"""
|
||||
`WebSocket close code`_.
|
||||
|
||||
.. _WebSocket close code:
|
||||
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5
|
||||
|
||||
:obj:`None` if the connection isn't closed yet.
|
||||
|
||||
"""
|
||||
if self.state is not CLOSED:
|
||||
return None
|
||||
elif self.close_rcvd is None:
|
||||
return 1006
|
||||
else:
|
||||
return self.close_rcvd.code
|
||||
|
||||
@property
|
||||
def close_reason(self) -> Optional[str]:
|
||||
"""
|
||||
`WebSocket close reason`_.
|
||||
|
||||
.. _WebSocket close reason:
|
||||
https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6
|
||||
|
||||
:obj:`None` if the connection isn't closed yet.
|
||||
|
||||
"""
|
||||
if self.state is not CLOSED:
|
||||
return None
|
||||
elif self.close_rcvd is None:
|
||||
return ""
|
||||
else:
|
||||
return self.close_rcvd.reason
|
||||
|
||||
@property
|
||||
def close_exc(self) -> ConnectionClosed:
|
||||
"""
|
||||
Exception to raise when trying to interact with a closed connection.
|
||||
|
||||
Don't raise this exception while the connection :attr:`state`
|
||||
is :attr:`~websockets.connection.State.CLOSING`; wait until
|
||||
it's :attr:`~websockets.connection.State.CLOSED`.
|
||||
|
||||
Indeed, the exception includes the close code and reason, which are
|
||||
known only once the connection is closed.
|
||||
|
||||
Raises:
|
||||
AssertionError: if the connection isn't closed yet.
|
||||
|
||||
"""
|
||||
assert self.state is CLOSED, "connection isn't closed yet"
|
||||
exc_type: Type[ConnectionClosed]
|
||||
if (
|
||||
self.close_rcvd is not None
|
||||
and self.close_sent is not None
|
||||
and self.close_rcvd.code in OK_CLOSE_CODES
|
||||
and self.close_sent.code in OK_CLOSE_CODES
|
||||
):
|
||||
exc_type = ConnectionClosedOK
|
||||
else:
|
||||
exc_type = ConnectionClosedError
|
||||
exc: ConnectionClosed = exc_type(
|
||||
self.close_rcvd,
|
||||
self.close_sent,
|
||||
self.close_rcvd_then_sent,
|
||||
)
|
||||
# Chain to the exception raised in the parser, if any.
|
||||
exc.__cause__ = self.parser_exc
|
||||
return exc
|
||||
|
||||
# Public methods for receiving data.
|
||||
|
||||
def receive_data(self, data: bytes) -> None:
|
||||
"""
|
||||
Receive data from the network.
|
||||
|
||||
After calling this method:
|
||||
|
||||
- You must call :meth:`data_to_send` and send this data to the network.
|
||||
- You should call :meth:`events_received` and process resulting events.
|
||||
|
||||
Raises:
|
||||
EOFError: if :meth:`receive_eof` was called earlier.
|
||||
|
||||
"""
|
||||
self.reader.feed_data(data)
|
||||
next(self.parser)
|
||||
|
||||
def receive_eof(self) -> None:
|
||||
"""
|
||||
Receive the end of the data stream from the network.
|
||||
|
||||
After calling this method:
|
||||
|
||||
- You must call :meth:`data_to_send` and send this data to the network.
|
||||
- You aren't expected to call :meth:`events_received`; it won't return
|
||||
any new events.
|
||||
|
||||
Raises:
|
||||
EOFError: if :meth:`receive_eof` was called earlier.
|
||||
|
||||
"""
|
||||
self.reader.feed_eof()
|
||||
next(self.parser)
|
||||
|
||||
# Public methods for sending events.
|
||||
|
||||
def send_continuation(self, data: bytes, fin: bool) -> None:
|
||||
"""
|
||||
Send a `Continuation frame`_.
|
||||
|
||||
.. _Continuation frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Parameters:
|
||||
data: payload containing the same kind of data
|
||||
as the initial frame.
|
||||
fin: FIN bit; set it to :obj:`True` if this is the last frame
|
||||
of a fragmented message and to :obj:`False` otherwise.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if a fragmented message isn't in progress.
|
||||
|
||||
"""
|
||||
if not self.expect_continuation_frame:
|
||||
raise ProtocolError("unexpected continuation frame")
|
||||
self.expect_continuation_frame = not fin
|
||||
self.send_frame(Frame(OP_CONT, data, fin))
|
||||
|
||||
def send_text(self, data: bytes, fin: bool = True) -> None:
|
||||
"""
|
||||
Send a `Text frame`_.
|
||||
|
||||
.. _Text frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Parameters:
|
||||
data: payload containing text encoded with UTF-8.
|
||||
fin: FIN bit; set it to :obj:`False` if this is the first frame of
|
||||
a fragmented message.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if a fragmented message is in progress.
|
||||
|
||||
"""
|
||||
if self.expect_continuation_frame:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
self.expect_continuation_frame = not fin
|
||||
self.send_frame(Frame(OP_TEXT, data, fin))
|
||||
|
||||
def send_binary(self, data: bytes, fin: bool = True) -> None:
|
||||
"""
|
||||
Send a `Binary frame`_.
|
||||
|
||||
.. _Binary frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
|
||||
|
||||
Parameters:
|
||||
data: payload containing arbitrary binary data.
|
||||
fin: FIN bit; set it to :obj:`False` if this is the first frame of
|
||||
a fragmented message.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if a fragmented message is in progress.
|
||||
|
||||
"""
|
||||
if self.expect_continuation_frame:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
self.expect_continuation_frame = not fin
|
||||
self.send_frame(Frame(OP_BINARY, data, fin))
|
||||
|
||||
def send_close(self, code: Optional[int] = None, reason: str = "") -> None:
|
||||
"""
|
||||
Send a `Close frame`_.
|
||||
|
||||
.. _Close frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.1
|
||||
|
||||
Parameters:
|
||||
code: close code.
|
||||
reason: close reason.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if a fragmented message is being sent, if the code
|
||||
isn't valid, or if a reason is provided without a code
|
||||
|
||||
"""
|
||||
if self.expect_continuation_frame:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
if code is None:
|
||||
if reason != "":
|
||||
raise ProtocolError("cannot send a reason without a code")
|
||||
close = Close(1005, "")
|
||||
data = b""
|
||||
else:
|
||||
close = Close(code, reason)
|
||||
data = close.serialize()
|
||||
# send_frame() guarantees that self.state is OPEN at this point.
|
||||
# 7.1.3. The WebSocket Closing Handshake is Started
|
||||
self.send_frame(Frame(OP_CLOSE, data))
|
||||
self.close_sent = close
|
||||
self.state = CLOSING
|
||||
|
||||
def send_ping(self, data: bytes) -> None:
|
||||
"""
|
||||
Send a `Ping frame`_.
|
||||
|
||||
.. _Ping frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2
|
||||
|
||||
Parameters:
|
||||
data: payload containing arbitrary binary data.
|
||||
|
||||
"""
|
||||
self.send_frame(Frame(OP_PING, data))
|
||||
|
||||
def send_pong(self, data: bytes) -> None:
|
||||
"""
|
||||
Send a `Pong frame`_.
|
||||
|
||||
.. _Pong frame:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3
|
||||
|
||||
Parameters:
|
||||
data: payload containing arbitrary binary data.
|
||||
|
||||
"""
|
||||
self.send_frame(Frame(OP_PONG, data))
|
||||
|
||||
def fail(self, code: int, reason: str = "") -> None:
|
||||
"""
|
||||
`Fail the WebSocket connection`_.
|
||||
|
||||
.. _Fail the WebSocket connection:
|
||||
https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7
|
||||
|
||||
Parameters:
|
||||
code: close code
|
||||
reason: close reason
|
||||
|
||||
Raises:
|
||||
ProtocolError: if the code isn't valid.
|
||||
"""
|
||||
# 7.1.7. Fail the WebSocket Connection
|
||||
|
||||
# Send a close frame when the state is OPEN (a close frame was already
|
||||
# sent if it's CLOSING), except when failing the connection because
|
||||
# of an error reading from or writing to the network.
|
||||
if self.state is OPEN:
|
||||
if code != 1006:
|
||||
close = Close(code, reason)
|
||||
data = close.serialize()
|
||||
self.send_frame(Frame(OP_CLOSE, data))
|
||||
self.close_sent = close
|
||||
self.state = CLOSING
|
||||
|
||||
# When failing the connection, a server closes the TCP connection
|
||||
# without waiting for the client to complete the handshake, while a
|
||||
# client waits for the server to close the TCP connection, possibly
|
||||
# after sending a close frame that the client will ignore.
|
||||
if self.side is SERVER and not self.eof_sent:
|
||||
self.send_eof()
|
||||
|
||||
# 7.1.7. Fail the WebSocket Connection "An endpoint MUST NOT continue
|
||||
# to attempt to process data(including a responding Close frame) from
|
||||
# the remote endpoint after being instructed to _Fail the WebSocket
|
||||
# Connection_."
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
|
||||
# Public method for getting incoming events after receiving data.
|
||||
|
||||
def events_received(self) -> List[Event]:
|
||||
"""
|
||||
Fetch events generated from data received from the network.
|
||||
|
||||
Call this method immediately after any of the ``receive_*()`` methods.
|
||||
|
||||
Process resulting events, likely by passing them to the application.
|
||||
|
||||
Returns:
|
||||
List[Event]: Events read from the connection.
|
||||
"""
|
||||
events, self.events = self.events, []
|
||||
return events
|
||||
|
||||
# Public method for getting outgoing data after receiving data or sending events.
|
||||
|
||||
def data_to_send(self) -> List[bytes]:
|
||||
"""
|
||||
Obtain data to send to the network.
|
||||
|
||||
Call this method immediately after any of the ``receive_*()``,
|
||||
``send_*()``, or :meth:`fail` methods.
|
||||
|
||||
Write resulting data to the connection.
|
||||
|
||||
The empty bytestring :data:`~websockets.connection.SEND_EOF` signals
|
||||
the end of the data stream. When you receive it, half-close the TCP
|
||||
connection.
|
||||
|
||||
Returns:
|
||||
List[bytes]: Data to write to the connection.
|
||||
|
||||
"""
|
||||
writes, self.writes = self.writes, []
|
||||
return writes
|
||||
|
||||
def close_expected(self) -> bool:
|
||||
"""
|
||||
Tell if the TCP connection is expected to close soon.
|
||||
|
||||
Call this method immediately after any of the ``receive_*()`` or
|
||||
:meth:`fail` methods.
|
||||
|
||||
If it returns :obj:`True`, schedule closing the TCP connection after a
|
||||
short timeout if the other side hasn't already closed it.
|
||||
|
||||
Returns:
|
||||
bool: Whether the TCP connection is expected to close soon.
|
||||
|
||||
"""
|
||||
# We expect a TCP close if and only if we sent a close frame:
|
||||
# * Normal closure: once we send a close frame, we expect a TCP close:
|
||||
# server waits for client to complete the TCP closing handshake;
|
||||
# client waits for server to initiate the TCP closing handshake.
|
||||
# * Abnormal closure: we always send a close frame and the same logic
|
||||
# applies, except on EOFError where we don't send a close frame
|
||||
# because we already received the TCP close, so we don't expect it.
|
||||
# We already got a TCP Close if and only if the state is CLOSED.
|
||||
return self.state is CLOSING or self.handshake_exc is not None
|
||||
|
||||
# Private methods for receiving data.
|
||||
|
||||
def parse(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Parse incoming data into frames.
|
||||
|
||||
:meth:`receive_data` and :meth:`receive_eof` run this generator
|
||||
coroutine until it needs more data or reaches EOF.
|
||||
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
if (yield from self.reader.at_eof()):
|
||||
if self.debug:
|
||||
self.logger.debug("< EOF")
|
||||
# If the WebSocket connection is closed cleanly, with a
|
||||
# closing handhshake, recv_frame() substitutes parse()
|
||||
# with discard(). This branch is reached only when the
|
||||
# connection isn't closed cleanly.
|
||||
raise EOFError("unexpected end of stream")
|
||||
|
||||
if self.max_size is None:
|
||||
max_size = None
|
||||
elif self.cur_size is None:
|
||||
max_size = self.max_size
|
||||
else:
|
||||
max_size = self.max_size - self.cur_size
|
||||
|
||||
# During a normal closure, execution ends here on the next
|
||||
# iteration of the loop after receiving a close frame. At
|
||||
# this point, recv_frame() replaced parse() by discard().
|
||||
frame = yield from Frame.parse(
|
||||
self.reader.read_exact,
|
||||
mask=self.side is SERVER,
|
||||
max_size=max_size,
|
||||
extensions=self.extensions,
|
||||
)
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< %s", frame)
|
||||
|
||||
self.recv_frame(frame)
|
||||
|
||||
except ProtocolError as exc:
|
||||
self.fail(1002, str(exc))
|
||||
self.parser_exc = exc
|
||||
|
||||
except EOFError as exc:
|
||||
self.fail(1006, str(exc))
|
||||
self.parser_exc = exc
|
||||
|
||||
except UnicodeDecodeError as exc:
|
||||
self.fail(1007, f"{exc.reason} at position {exc.start}")
|
||||
self.parser_exc = exc
|
||||
|
||||
except PayloadTooBig as exc:
|
||||
self.fail(1009, str(exc))
|
||||
self.parser_exc = exc
|
||||
|
||||
except Exception as exc:
|
||||
self.logger.error("parser failed", exc_info=True)
|
||||
# Don't include exception details, which may be security-sensitive.
|
||||
self.fail(1011)
|
||||
self.parser_exc = exc
|
||||
|
||||
# During an abnormal closure, execution ends here after catching an
|
||||
# exception. At this point, fail() replaced parse() by discard().
|
||||
yield
|
||||
raise AssertionError("parse() shouldn't step after error") # pragma: no cover
|
||||
|
||||
def discard(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Discard incoming data.
|
||||
|
||||
This coroutine replaces :meth:`parse`:
|
||||
|
||||
- after receiving a close frame, during a normal closure (1.4);
|
||||
- after sending a close frame, during an abnormal closure (7.1.7).
|
||||
|
||||
"""
|
||||
# The server close the TCP connection in the same circumstances where
|
||||
# discard() replaces parse(). The client closes the connection later,
|
||||
# after the server closes the connection or a timeout elapses.
|
||||
# (The latter case cannot be handled in this Sans-I/O layer.)
|
||||
assert (self.side is SERVER) == (self.eof_sent)
|
||||
while not (yield from self.reader.at_eof()):
|
||||
self.reader.discard()
|
||||
if self.debug:
|
||||
self.logger.debug("< EOF")
|
||||
# A server closes the TCP connection immediately, while a client
|
||||
# waits for the server to close the TCP connection.
|
||||
if self.side is CLIENT:
|
||||
self.send_eof()
|
||||
self.state = CLOSED
|
||||
# If discard() completes normally, execution ends here.
|
||||
yield
|
||||
# Once the reader reaches EOF, its feed_data/eof() methods raise an
|
||||
# error, so our receive_data/eof() methods don't step the generator.
|
||||
raise AssertionError("discard() shouldn't step after EOF") # pragma: no cover
|
||||
|
||||
def recv_frame(self, frame: Frame) -> None:
|
||||
"""
|
||||
Process an incoming frame.
|
||||
|
||||
"""
|
||||
if frame.opcode is OP_TEXT or frame.opcode is OP_BINARY:
|
||||
if self.cur_size is not None:
|
||||
raise ProtocolError("expected a continuation frame")
|
||||
if frame.fin:
|
||||
self.cur_size = None
|
||||
else:
|
||||
self.cur_size = len(frame.data)
|
||||
|
||||
elif frame.opcode is OP_CONT:
|
||||
if self.cur_size is None:
|
||||
raise ProtocolError("unexpected continuation frame")
|
||||
if frame.fin:
|
||||
self.cur_size = None
|
||||
else:
|
||||
self.cur_size += len(frame.data)
|
||||
|
||||
elif frame.opcode is OP_PING:
|
||||
# 5.5.2. Ping: "Upon receipt of a Ping frame, an endpoint MUST
|
||||
# send a Pong frame in response"
|
||||
pong_frame = Frame(OP_PONG, frame.data)
|
||||
self.send_frame(pong_frame)
|
||||
|
||||
elif frame.opcode is OP_PONG:
|
||||
# 5.5.3 Pong: "A response to an unsolicited Pong frame is not
|
||||
# expected."
|
||||
pass
|
||||
|
||||
elif frame.opcode is OP_CLOSE:
|
||||
# 7.1.5. The WebSocket Connection Close Code
|
||||
# 7.1.6. The WebSocket Connection Close Reason
|
||||
self.close_rcvd = Close.parse(frame.data)
|
||||
if self.state is CLOSING:
|
||||
assert self.close_sent is not None
|
||||
self.close_rcvd_then_sent = False
|
||||
|
||||
if self.cur_size is not None:
|
||||
raise ProtocolError("incomplete fragmented message")
|
||||
|
||||
# 5.5.1 Close: "If an endpoint receives a Close frame and did
|
||||
# not previously send a Close frame, the endpoint MUST send a
|
||||
# Close frame in response. (When sending a Close frame in
|
||||
# response, the endpoint typically echos the status code it
|
||||
# received.)"
|
||||
|
||||
if self.state is OPEN:
|
||||
# Echo the original data instead of re-serializing it with
|
||||
# Close.serialize() because that fails when the close frame
|
||||
# is empty and Close.parse() synthetizes a 1005 close code.
|
||||
# The rest is identical to send_close().
|
||||
self.send_frame(Frame(OP_CLOSE, frame.data))
|
||||
self.close_sent = self.close_rcvd
|
||||
self.close_rcvd_then_sent = True
|
||||
self.state = CLOSING
|
||||
|
||||
# 7.1.2. Start the WebSocket Closing Handshake: "Once an
|
||||
# endpoint has both sent and received a Close control frame,
|
||||
# that endpoint SHOULD _Close the WebSocket Connection_"
|
||||
|
||||
# A server closes the TCP connection immediately, while a client
|
||||
# waits for the server to close the TCP connection.
|
||||
if self.side is SERVER:
|
||||
self.send_eof()
|
||||
|
||||
# 1.4. Closing Handshake: "after receiving a control frame
|
||||
# indicating the connection should be closed, a peer discards
|
||||
# any further data received."
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
|
||||
else: # pragma: no cover
|
||||
# This can't happen because Frame.parse() validates opcodes.
|
||||
raise AssertionError(f"unexpected opcode: {frame.opcode:02x}")
|
||||
|
||||
self.events.append(frame)
|
||||
|
||||
# Private methods for sending events.
|
||||
|
||||
def send_frame(self, frame: Frame) -> None:
|
||||
if self.state is not OPEN:
|
||||
raise InvalidState(
|
||||
f"cannot write to a WebSocket in the {self.state.name} state"
|
||||
)
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("> %s", frame)
|
||||
self.writes.append(
|
||||
frame.serialize(mask=self.side is CLIENT, extensions=self.extensions)
|
||||
)
|
||||
|
||||
def send_eof(self) -> None:
|
||||
assert not self.eof_sent
|
||||
self.eof_sent = True
|
||||
if self.debug:
|
||||
self.logger.debug("> EOF")
|
||||
self.writes.append(SEND_EOF)
|
||||
200
venv/lib/python3.11/site-packages/websockets/datastructures.py
Normal file
200
venv/lib/python3.11/site-packages/websockets/datastructures.py
Normal file
@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
if sys.version_info[:2] >= (3, 8):
|
||||
from typing import Protocol
|
||||
else: # pragma: no cover
|
||||
Protocol = object # mypy will report errors on Python 3.7.
|
||||
|
||||
|
||||
__all__ = ["Headers", "HeadersLike", "MultipleValuesError"]
|
||||
|
||||
|
||||
class MultipleValuesError(LookupError):
|
||||
"""
|
||||
Exception raised when :class:`Headers` has more than one value for a key.
|
||||
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Implement the same logic as KeyError_str in Objects/exceptions.c.
|
||||
if len(self.args) == 1:
|
||||
return repr(self.args[0])
|
||||
return super().__str__()
|
||||
|
||||
|
||||
class Headers(MutableMapping[str, str]):
|
||||
"""
|
||||
Efficient data structure for manipulating HTTP headers.
|
||||
|
||||
A :class:`list` of ``(name, values)`` is inefficient for lookups.
|
||||
|
||||
A :class:`dict` doesn't suffice because header names are case-insensitive
|
||||
and multiple occurrences of headers with the same name are possible.
|
||||
|
||||
:class:`Headers` stores HTTP headers in a hybrid data structure to provide
|
||||
efficient insertions and lookups while preserving the original data.
|
||||
|
||||
In order to account for multiple values with minimal hassle,
|
||||
:class:`Headers` follows this logic:
|
||||
|
||||
- When getting a header with ``headers[name]``:
|
||||
- if there's no value, :exc:`KeyError` is raised;
|
||||
- if there's exactly one value, it's returned;
|
||||
- if there's more than one value, :exc:`MultipleValuesError` is raised.
|
||||
|
||||
- When setting a header with ``headers[name] = value``, the value is
|
||||
appended to the list of values for that header.
|
||||
|
||||
- When deleting a header with ``del headers[name]``, all values for that
|
||||
header are removed (this is slow).
|
||||
|
||||
Other methods for manipulating headers are consistent with this logic.
|
||||
|
||||
As long as no header occurs multiple times, :class:`Headers` behaves like
|
||||
:class:`dict`, except keys are lower-cased to provide case-insensitivity.
|
||||
|
||||
Two methods support manipulating multiple values explicitly:
|
||||
|
||||
- :meth:`get_all` returns a list of all values for a header;
|
||||
- :meth:`raw_items` returns an iterator of ``(name, values)`` pairs.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ["_dict", "_list"]
|
||||
|
||||
# Like dict, Headers accepts an optional "mapping or iterable" argument.
|
||||
def __init__(self, *args: HeadersLike, **kwargs: str) -> None:
|
||||
self._dict: Dict[str, List[str]] = {}
|
||||
self._list: List[Tuple[str, str]] = []
|
||||
self.update(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(f"{key}: {value}\r\n" for key, value in self._list) + "\r\n"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self._list!r})"
|
||||
|
||||
def copy(self) -> Headers:
|
||||
copy = self.__class__()
|
||||
copy._dict = self._dict.copy()
|
||||
copy._list = self._list.copy()
|
||||
return copy
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
# Since headers only contain ASCII characters, we can keep this simple.
|
||||
return str(self).encode()
|
||||
|
||||
# Collection methods
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return isinstance(key, str) and key.lower() in self._dict
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._dict)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
# MutableMapping methods
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
value = self._dict[key.lower()]
|
||||
if len(value) == 1:
|
||||
return value[0]
|
||||
else:
|
||||
raise MultipleValuesError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
self._dict.setdefault(key.lower(), []).append(value)
|
||||
self._list.append((key, value))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
key_lower = key.lower()
|
||||
self._dict.__delitem__(key_lower)
|
||||
# This is inefficient. Fortunately deleting HTTP headers is uncommon.
|
||||
self._list = [(k, v) for k, v in self._list if k.lower() != key_lower]
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return NotImplemented
|
||||
return self._dict == other._dict
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Remove all headers.
|
||||
|
||||
"""
|
||||
self._dict = {}
|
||||
self._list = []
|
||||
|
||||
def update(self, *args: HeadersLike, **kwargs: str) -> None:
|
||||
"""
|
||||
Update from a :class:`Headers` instance and/or keyword arguments.
|
||||
|
||||
"""
|
||||
args = tuple(
|
||||
arg.raw_items() if isinstance(arg, Headers) else arg for arg in args
|
||||
)
|
||||
super().update(*args, **kwargs)
|
||||
|
||||
# Methods for handling multiple values
|
||||
|
||||
def get_all(self, key: str) -> List[str]:
|
||||
"""
|
||||
Return the (possibly empty) list of all values for a header.
|
||||
|
||||
Args:
|
||||
key: header name.
|
||||
|
||||
"""
|
||||
return self._dict.get(key.lower(), [])
|
||||
|
||||
def raw_items(self) -> Iterator[Tuple[str, str]]:
|
||||
"""
|
||||
Return an iterator of all values as ``(name, value)`` pairs.
|
||||
|
||||
"""
|
||||
return iter(self._list)
|
||||
|
||||
|
||||
# copy of _typeshed.SupportsKeysAndGetItem.
|
||||
class SupportsKeysAndGetItem(Protocol): # pragma: no cover
|
||||
"""
|
||||
Dict-like types with ``keys() -> str`` and ``__getitem__(key: str) -> str`` methods.
|
||||
|
||||
"""
|
||||
|
||||
def keys(self) -> Iterable[str]:
|
||||
...
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
...
|
||||
|
||||
|
||||
HeadersLike = Union[
|
||||
Headers,
|
||||
Mapping[str, str],
|
||||
Iterable[Tuple[str, str]],
|
||||
SupportsKeysAndGetItem,
|
||||
]
|
||||
"""
|
||||
Types accepted where :class:`Headers` is expected.
|
||||
|
||||
In addition to :class:`Headers` itself, this includes dict-like types where both
|
||||
keys and values are :class:`str`.
|
||||
|
||||
"""
|
||||
398
venv/lib/python3.11/site-packages/websockets/exceptions.py
Normal file
398
venv/lib/python3.11/site-packages/websockets/exceptions.py
Normal file
@ -0,0 +1,398 @@
|
||||
"""
|
||||
:mod:`websockets.exceptions` defines the following exception hierarchy:
|
||||
|
||||
* :exc:`WebSocketException`
|
||||
* :exc:`ConnectionClosed`
|
||||
* :exc:`ConnectionClosedError`
|
||||
* :exc:`ConnectionClosedOK`
|
||||
* :exc:`InvalidHandshake`
|
||||
* :exc:`SecurityError`
|
||||
* :exc:`InvalidMessage`
|
||||
* :exc:`InvalidHeader`
|
||||
* :exc:`InvalidHeaderFormat`
|
||||
* :exc:`InvalidHeaderValue`
|
||||
* :exc:`InvalidOrigin`
|
||||
* :exc:`InvalidUpgrade`
|
||||
* :exc:`InvalidStatus`
|
||||
* :exc:`InvalidStatusCode` (legacy)
|
||||
* :exc:`NegotiationError`
|
||||
* :exc:`DuplicateParameter`
|
||||
* :exc:`InvalidParameterName`
|
||||
* :exc:`InvalidParameterValue`
|
||||
* :exc:`AbortHandshake`
|
||||
* :exc:`RedirectHandshake`
|
||||
* :exc:`InvalidState`
|
||||
* :exc:`InvalidURI`
|
||||
* :exc:`PayloadTooBig`
|
||||
* :exc:`ProtocolError`
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import http
|
||||
from typing import Optional
|
||||
|
||||
from . import datastructures, frames, http11
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WebSocketException",
|
||||
"ConnectionClosed",
|
||||
"ConnectionClosedError",
|
||||
"ConnectionClosedOK",
|
||||
"InvalidHandshake",
|
||||
"SecurityError",
|
||||
"InvalidMessage",
|
||||
"InvalidHeader",
|
||||
"InvalidHeaderFormat",
|
||||
"InvalidHeaderValue",
|
||||
"InvalidOrigin",
|
||||
"InvalidUpgrade",
|
||||
"InvalidStatus",
|
||||
"InvalidStatusCode",
|
||||
"NegotiationError",
|
||||
"DuplicateParameter",
|
||||
"InvalidParameterName",
|
||||
"InvalidParameterValue",
|
||||
"AbortHandshake",
|
||||
"RedirectHandshake",
|
||||
"InvalidState",
|
||||
"InvalidURI",
|
||||
"PayloadTooBig",
|
||||
"ProtocolError",
|
||||
"WebSocketProtocolError",
|
||||
]
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
"""
|
||||
Base class for all exceptions defined by websockets.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosed(WebSocketException):
|
||||
"""
|
||||
Raised when trying to interact with a closed connection.
|
||||
|
||||
Attributes:
|
||||
rcvd (Optional[Close]): if a close frame was received, its code and
|
||||
reason are available in ``rcvd.code`` and ``rcvd.reason``.
|
||||
sent (Optional[Close]): if a close frame was sent, its code and reason
|
||||
are available in ``sent.code`` and ``sent.reason``.
|
||||
rcvd_then_sent (Optional[bool]): if close frames were received and
|
||||
sent, this attribute tells in which order this happened, from the
|
||||
perspective of this side of the connection.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rcvd: Optional[frames.Close],
|
||||
sent: Optional[frames.Close],
|
||||
rcvd_then_sent: Optional[bool] = None,
|
||||
) -> None:
|
||||
self.rcvd = rcvd
|
||||
self.sent = sent
|
||||
self.rcvd_then_sent = rcvd_then_sent
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.rcvd is None:
|
||||
if self.sent is None:
|
||||
assert self.rcvd_then_sent is None
|
||||
return "no close frame received or sent"
|
||||
else:
|
||||
assert self.rcvd_then_sent is None
|
||||
return f"sent {self.sent}; no close frame received"
|
||||
else:
|
||||
if self.sent is None:
|
||||
assert self.rcvd_then_sent is None
|
||||
return f"received {self.rcvd}; no close frame sent"
|
||||
else:
|
||||
assert self.rcvd_then_sent is not None
|
||||
if self.rcvd_then_sent:
|
||||
return f"received {self.rcvd}; then sent {self.sent}"
|
||||
else:
|
||||
return f"sent {self.sent}; then received {self.rcvd}"
|
||||
|
||||
# code and reason attributes are provided for backwards-compatibility
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
return 1006 if self.rcvd is None else self.rcvd.code
|
||||
|
||||
@property
|
||||
def reason(self) -> str:
|
||||
return "" if self.rcvd is None else self.rcvd.reason
|
||||
|
||||
|
||||
class ConnectionClosedError(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated with an error.
|
||||
|
||||
A close code other than 1000 (OK) or 1001 (going away) was received or
|
||||
sent, or the closing handshake didn't complete properly.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ConnectionClosedOK(ConnectionClosed):
|
||||
"""
|
||||
Like :exc:`ConnectionClosed`, when the connection terminated properly.
|
||||
|
||||
A close code 1000 (OK) or 1001 (going away) was received and sent.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidHandshake(WebSocketException):
|
||||
"""
|
||||
Raised during the handshake when the WebSocket connection fails.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SecurityError(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response breaks a security rule.
|
||||
|
||||
Security limits are hard coded.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidMessage(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response is malformed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidHeader(InvalidHandshake):
|
||||
"""
|
||||
Raised when an HTTP header doesn't have a valid format or value.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, value: Optional[str] = None) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.value is None:
|
||||
return f"missing {self.name} header"
|
||||
elif self.value == "":
|
||||
return f"empty {self.name} header"
|
||||
else:
|
||||
return f"invalid {self.name} header: {self.value}"
|
||||
|
||||
|
||||
class InvalidHeaderFormat(InvalidHeader):
|
||||
"""
|
||||
Raised when an HTTP header cannot be parsed.
|
||||
|
||||
The format of the header doesn't match the grammar for that header.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, error: str, header: str, pos: int) -> None:
|
||||
super().__init__(name, f"{error} at {pos} in {header}")
|
||||
|
||||
|
||||
class InvalidHeaderValue(InvalidHeader):
|
||||
"""
|
||||
Raised when an HTTP header has a wrong value.
|
||||
|
||||
The format of the header is correct but a value isn't acceptable.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidOrigin(InvalidHeader):
|
||||
"""
|
||||
Raised when the Origin header in a request isn't allowed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, origin: Optional[str]) -> None:
|
||||
super().__init__("Origin", origin)
|
||||
|
||||
|
||||
class InvalidUpgrade(InvalidHeader):
|
||||
"""
|
||||
Raised when the Upgrade or Connection header isn't correct.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStatus(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response rejects the WebSocket upgrade.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, response: http11.Response) -> None:
|
||||
self.response = response
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"server rejected WebSocket connection: "
|
||||
f"HTTP {self.response.status_code:d}"
|
||||
)
|
||||
|
||||
|
||||
class InvalidStatusCode(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response status code is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"server rejected WebSocket connection: HTTP {self.status_code}"
|
||||
|
||||
|
||||
class NegotiationError(InvalidHandshake):
|
||||
"""
|
||||
Raised when negotiating an extension fails.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DuplicateParameter(NegotiationError):
|
||||
"""
|
||||
Raised when a parameter name is repeated in an extension header.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"duplicate parameter: {self.name}"
|
||||
|
||||
|
||||
class InvalidParameterName(NegotiationError):
|
||||
"""
|
||||
Raised when a parameter name in an extension header is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"invalid parameter name: {self.name}"
|
||||
|
||||
|
||||
class InvalidParameterValue(NegotiationError):
|
||||
"""
|
||||
Raised when a parameter value in an extension header is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, value: Optional[str]) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.value is None:
|
||||
return f"missing value for parameter {self.name}"
|
||||
elif self.value == "":
|
||||
return f"empty value for parameter {self.name}"
|
||||
else:
|
||||
return f"invalid value for parameter {self.name}: {self.value}"
|
||||
|
||||
|
||||
class AbortHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised to abort the handshake on purpose and return an HTTP response.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
The public API
|
||||
is :meth:`~websockets.server.WebSocketServerProtocol.process_request`.
|
||||
|
||||
Attributes:
|
||||
status (~http.HTTPStatus): HTTP status code.
|
||||
headers (Headers): HTTP response headers.
|
||||
body (bytes): HTTP response body.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: http.HTTPStatus,
|
||||
headers: datastructures.HeadersLike,
|
||||
body: bytes = b"",
|
||||
) -> None:
|
||||
self.status = status
|
||||
self.headers = datastructures.Headers(headers)
|
||||
self.body = body
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"HTTP {self.status:d}, "
|
||||
f"{len(self.headers)} headers, "
|
||||
f"{len(self.body)} bytes"
|
||||
)
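# Hedged sketch of the public counterpart mentioned above: instead of raising
# AbortHandshake directly, a process_request hook can return an HTTP response
# to short-circuit the handshake. The /healthz path is a made-up example.
import http

async def health_check(path, request_headers):
    # Returning a (status, headers, body) tuple aborts the WebSocket
    # handshake and sends a plain HTTP response instead.
    if path == "/healthz":
        return http.HTTPStatus.OK, [], b"OK\n"
    # Returning None lets the WebSocket handshake proceed normally.
    return None

# Typical wiring with the legacy asyncio server:
# websockets.serve(handler, "localhost", 8765, process_request=health_check)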
|
||||
|
||||
|
||||
class RedirectHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake gets redirected.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str) -> None:
|
||||
self.uri = uri
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"redirect to {self.uri}"
|
||||
|
||||
|
||||
class InvalidState(WebSocketException, AssertionError):
|
||||
"""
|
||||
Raised when an operation is forbidden in the current state.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
It should never be raised in normal circumstances.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidURI(WebSocketException):
|
||||
"""
|
||||
Raised when connecting to a URI that isn't a valid WebSocket URI.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str, msg: str) -> None:
|
||||
self.uri = uri
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.uri} isn't a valid URI: {self.msg}"
|
||||
|
||||
|
||||
class PayloadTooBig(WebSocketException):
|
||||
"""
|
||||
Raised when receiving a frame with a payload exceeding the maximum size.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ProtocolError(WebSocketException):
|
||||
"""
|
||||
Raised when a frame breaks the protocol.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
WebSocketProtocolError = ProtocolError # for backwards compatibility
|
||||
4
venv/lib/python3.11/site-packages/websockets/extensions/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
|
||||
|
||||
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
|
||||
128
venv/lib/python3.11/site-packages/websockets/extensions/base.py
Normal file
@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
from .. import frames
|
||||
from ..typing import ExtensionName, ExtensionParameter
|
||||
|
||||
|
||||
__all__ = ["Extension", "ClientExtensionFactory", "ServerExtensionFactory"]
|
||||
|
||||
|
||||
class Extension:
|
||||
"""
|
||||
Base class for extensions.
|
||||
|
||||
"""
|
||||
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def decode(
|
||||
self,
|
||||
frame: frames.Frame,
|
||||
*,
|
||||
max_size: Optional[int] = None,
|
||||
) -> frames.Frame:
|
||||
"""
|
||||
Decode an incoming frame.
|
||||
|
||||
Args:
|
||||
frame (Frame): incoming frame.
|
||||
max_size: maximum payload size in bytes.
|
||||
|
||||
Returns:
|
||||
Frame: Decoded frame.
|
||||
|
||||
Raises:
|
||||
PayloadTooBig: if decoding the payload exceeds ``max_size``.
|
||||
|
||||
"""
|
||||
|
||||
def encode(self, frame: frames.Frame) -> frames.Frame:
|
||||
"""
|
||||
Encode an outgoing frame.
|
||||
|
||||
Args:
|
||||
frame (Frame): outgoing frame.
|
||||
|
||||
Returns:
|
||||
Frame: Encoded frame.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ClientExtensionFactory:
|
||||
"""
|
||||
Base class for client-side extension factories.
|
||||
|
||||
"""
|
||||
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def get_request_params(self) -> List[ExtensionParameter]:
|
||||
"""
|
||||
Build parameters to send to the server for this extension.
|
||||
|
||||
Returns:
|
||||
List[ExtensionParameter]: Parameters to send to the server.
|
||||
|
||||
"""
|
||||
|
||||
def process_response_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> Extension:
|
||||
"""
|
||||
Process parameters received from the server.
|
||||
|
||||
Args:
|
||||
params (Sequence[ExtensionParameter]): parameters received from
|
||||
the server for this extension.
|
||||
accepted_extensions (Sequence[Extension]): list of previously
|
||||
accepted extensions.
|
||||
|
||||
Returns:
|
||||
Extension: An extension instance.
|
||||
|
||||
Raises:
|
||||
NegotiationError: if parameters aren't acceptable.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ServerExtensionFactory:
|
||||
"""
|
||||
Base class for server-side extension factories.
|
||||
|
||||
"""
|
||||
|
||||
name: ExtensionName
|
||||
"""Extension identifier."""
|
||||
|
||||
def process_request_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> Tuple[List[ExtensionParameter], Extension]:
|
||||
"""
|
||||
Process parameters received from the client.
|
||||
|
||||
Args:
|
||||
params (Sequence[ExtensionParameter]): parameters received from
|
||||
the client for this extension.
|
||||
accepted_extensions (Sequence[Extension]): list of previously
|
||||
accepted extensions.
|
||||
|
||||
Returns:
|
||||
Tuple[List[ExtensionParameter], Extension]: To accept the offer,
|
||||
parameters to send to the client for this extension and an
|
||||
extension instance.
|
||||
|
||||
Raises:
|
||||
NegotiationError: to reject the offer, if parameters received from
|
||||
the client aren't acceptable.
|
||||
|
||||
"""
|
||||
661
venv/lib/python3.11/site-packages/websockets/extensions/permessage_deflate.py
Normal file
@ -0,0 +1,661 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import zlib
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from .. import exceptions, frames
|
||||
from ..typing import ExtensionName, ExtensionParameter
|
||||
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PerMessageDeflate",
|
||||
"ClientPerMessageDeflateFactory",
|
||||
"enable_client_permessage_deflate",
|
||||
"ServerPerMessageDeflateFactory",
|
||||
"enable_server_permessage_deflate",
|
||||
]
|
||||
|
||||
_EMPTY_UNCOMPRESSED_BLOCK = b"\x00\x00\xff\xff"
|
||||
|
||||
_MAX_WINDOW_BITS_VALUES = [str(bits) for bits in range(8, 16)]
|
||||
|
||||
|
||||
class PerMessageDeflate(Extension):
|
||||
"""
|
||||
Per-Message Deflate extension.
|
||||
|
||||
"""
|
||||
|
||||
name = ExtensionName("permessage-deflate")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
remote_no_context_takeover: bool,
|
||||
local_no_context_takeover: bool,
|
||||
remote_max_window_bits: int,
|
||||
local_max_window_bits: int,
|
||||
compress_settings: Optional[Dict[Any, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Configure the Per-Message Deflate extension.
|
||||
|
||||
"""
|
||||
if compress_settings is None:
|
||||
compress_settings = {}
|
||||
|
||||
assert remote_no_context_takeover in [False, True]
|
||||
assert local_no_context_takeover in [False, True]
|
||||
assert 8 <= remote_max_window_bits <= 15
|
||||
assert 8 <= local_max_window_bits <= 15
|
||||
assert "wbits" not in compress_settings
|
||||
|
||||
self.remote_no_context_takeover = remote_no_context_takeover
|
||||
self.local_no_context_takeover = local_no_context_takeover
|
||||
self.remote_max_window_bits = remote_max_window_bits
|
||||
self.local_max_window_bits = local_max_window_bits
|
||||
self.compress_settings = compress_settings
|
||||
|
||||
if not self.remote_no_context_takeover:
|
||||
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
|
||||
|
||||
if not self.local_no_context_takeover:
|
||||
self.encoder = zlib.compressobj(
|
||||
wbits=-self.local_max_window_bits, **self.compress_settings
|
||||
)
|
||||
|
||||
# To handle continuation frames properly, we must keep track of
|
||||
# whether the initial frame of the message was encoded.
|
||||
self.decode_cont_data = False
|
||||
# There's no need for self.encode_cont_data because we always encode
|
||||
# outgoing frames, so it would always be True.
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PerMessageDeflate("
|
||||
f"remote_no_context_takeover={self.remote_no_context_takeover}, "
|
||||
f"local_no_context_takeover={self.local_no_context_takeover}, "
|
||||
f"remote_max_window_bits={self.remote_max_window_bits}, "
|
||||
f"local_max_window_bits={self.local_max_window_bits})"
|
||||
)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
frame: frames.Frame,
|
||||
*,
|
||||
max_size: Optional[int] = None,
|
||||
) -> frames.Frame:
|
||||
"""
|
||||
Decode an incoming frame.
|
||||
|
||||
"""
|
||||
# Skip control frames.
|
||||
if frame.opcode in frames.CTRL_OPCODES:
|
||||
return frame
|
||||
|
||||
# Handle continuation data frames:
|
||||
# - skip if the message isn't encoded
|
||||
# - reset "decode continuation data" flag if it's a final frame
|
||||
if frame.opcode is frames.OP_CONT:
|
||||
if not self.decode_cont_data:
|
||||
return frame
|
||||
if frame.fin:
|
||||
self.decode_cont_data = False
|
||||
|
||||
# Handle text and binary data frames:
|
||||
# - skip if the message isn't encoded
|
||||
# - unset the rsv1 flag on the first frame of a compressed message
|
||||
# - set "decode continuation data" flag if it's a non-final frame
|
||||
else:
|
||||
if not frame.rsv1:
|
||||
return frame
|
||||
frame = dataclasses.replace(frame, rsv1=False)
|
||||
if not frame.fin:
|
||||
self.decode_cont_data = True
|
||||
|
||||
# Re-initialize per-message decoder.
|
||||
if self.remote_no_context_takeover:
|
||||
self.decoder = zlib.decompressobj(wbits=-self.remote_max_window_bits)
|
||||
|
||||
# Uncompress data. Protect against zip bombs by preventing zlib from
|
||||
# decompressing more than max_length bytes (except when the limit is
|
||||
# disabled with max_size = None).
|
||||
data = frame.data
|
||||
if frame.fin:
|
||||
data += _EMPTY_UNCOMPRESSED_BLOCK
|
||||
max_length = 0 if max_size is None else max_size
|
||||
try:
|
||||
data = self.decoder.decompress(data, max_length)
|
||||
except zlib.error as exc:
|
||||
raise exceptions.ProtocolError("decompression failed") from exc
|
||||
if self.decoder.unconsumed_tail:
|
||||
raise exceptions.PayloadTooBig(f"over size limit (? > {max_size} bytes)")
|
||||
|
||||
# Allow garbage collection of the decoder if it won't be reused.
|
||||
if frame.fin and self.remote_no_context_takeover:
|
||||
del self.decoder
|
||||
|
||||
return dataclasses.replace(frame, data=data)
|
||||
|
||||
def encode(self, frame: frames.Frame) -> frames.Frame:
|
||||
"""
|
||||
Encode an outgoing frame.
|
||||
|
||||
"""
|
||||
# Skip control frames.
|
||||
if frame.opcode in frames.CTRL_OPCODES:
|
||||
return frame
|
||||
|
||||
# Since we always encode messages, there's no "encode continuation
|
||||
# data" flag similar to "decode continuation data" at this time.
|
||||
|
||||
if frame.opcode is not frames.OP_CONT:
|
||||
# Set the rsv1 flag on the first frame of a compressed message.
|
||||
frame = dataclasses.replace(frame, rsv1=True)
|
||||
# Re-initialize per-message encoder.
|
||||
if self.local_no_context_takeover:
|
||||
self.encoder = zlib.compressobj(
|
||||
wbits=-self.local_max_window_bits, **self.compress_settings
|
||||
)
|
||||
|
||||
# Compress data.
|
||||
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
|
||||
if frame.fin and data.endswith(_EMPTY_UNCOMPRESSED_BLOCK):
|
||||
data = data[:-4]
|
||||
|
||||
# Allow garbage collection of the encoder if it won't be reused.
|
||||
if frame.fin and self.local_no_context_takeover:
|
||||
del self.encoder
|
||||
|
||||
return dataclasses.replace(frame, data=data)
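# Hedged illustration (not executed by the library): a single PerMessageDeflate
# instance compressing a text frame and decompressing it again, with both
# sides configured symmetrically.
ext = PerMessageDeflate(
    remote_no_context_takeover=False,
    local_no_context_takeover=False,
    remote_max_window_bits=15,
    local_max_window_bits=15,
)
original = frames.Frame(frames.OP_TEXT, b"hello hello hello")
compressed = ext.encode(original)
assert compressed.rsv1 is True  # RSV1 marks the first frame of a compressed message
restored = ext.decode(compressed)
assert restored.data == b"hello hello hello"
assert restored.rsv1 is False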
|
||||
|
||||
|
||||
def _build_parameters(
|
||||
server_no_context_takeover: bool,
|
||||
client_no_context_takeover: bool,
|
||||
server_max_window_bits: Optional[int],
|
||||
client_max_window_bits: Optional[Union[int, bool]],
|
||||
) -> List[ExtensionParameter]:
|
||||
"""
|
||||
Build a list of ``(name, value)`` pairs for some compression parameters.
|
||||
|
||||
"""
|
||||
params: List[ExtensionParameter] = []
|
||||
if server_no_context_takeover:
|
||||
params.append(("server_no_context_takeover", None))
|
||||
if client_no_context_takeover:
|
||||
params.append(("client_no_context_takeover", None))
|
||||
if server_max_window_bits:
|
||||
params.append(("server_max_window_bits", str(server_max_window_bits)))
|
||||
if client_max_window_bits is True: # only in handshake requests
|
||||
params.append(("client_max_window_bits", None))
|
||||
elif client_max_window_bits:
|
||||
params.append(("client_max_window_bits", str(client_max_window_bits)))
|
||||
return params
|
||||
|
||||
|
||||
def _extract_parameters(
|
||||
params: Sequence[ExtensionParameter], *, is_server: bool
|
||||
) -> Tuple[bool, bool, Optional[int], Optional[Union[int, bool]]]:
|
||||
"""
|
||||
Extract compression parameters from a list of ``(name, value)`` pairs.
|
||||
|
||||
If ``is_server`` is :obj:`True`, ``client_max_window_bits`` may be
|
||||
provided without a value. This is only allowed in handshake requests.
|
||||
|
||||
"""
|
||||
server_no_context_takeover: bool = False
|
||||
client_no_context_takeover: bool = False
|
||||
server_max_window_bits: Optional[int] = None
|
||||
client_max_window_bits: Optional[Union[int, bool]] = None
|
||||
|
||||
for name, value in params:
|
||||
|
||||
if name == "server_no_context_takeover":
|
||||
if server_no_context_takeover:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
if value is None:
|
||||
server_no_context_takeover = True
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "client_no_context_takeover":
|
||||
if client_no_context_takeover:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
if value is None:
|
||||
client_no_context_takeover = True
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "server_max_window_bits":
|
||||
if server_max_window_bits is not None:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
if value in _MAX_WINDOW_BITS_VALUES:
|
||||
server_max_window_bits = int(value)
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
|
||||
elif name == "client_max_window_bits":
|
||||
if client_max_window_bits is not None:
|
||||
raise exceptions.DuplicateParameter(name)
|
||||
if is_server and value is None: # only in handshake requests
|
||||
client_max_window_bits = True
|
||||
elif value in _MAX_WINDOW_BITS_VALUES:
|
||||
client_max_window_bits = int(value)
|
||||
else:
|
||||
raise exceptions.InvalidParameterValue(name, value)
|
||||
|
||||
else:
|
||||
raise exceptions.InvalidParameterName(name)
|
||||
|
||||
return (
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
)
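# Hedged illustration (not executed by the library) of the two helpers above
# round-tripping a typical client offer.
offer = _build_parameters(
    server_no_context_takeover=True,
    client_no_context_takeover=False,
    server_max_window_bits=12,
    client_max_window_bits=True,  # bare flag, only valid in handshake requests
)
# offer == [("server_no_context_takeover", None),
#           ("server_max_window_bits", "12"),
#           ("client_max_window_bits", None)]
assert _extract_parameters(offer, is_server=True) == (True, False, 12, True)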
|
||||
|
||||
|
||||
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
|
||||
"""
|
||||
Client-side extension factory for the Per-Message Deflate extension.
|
||||
|
||||
Parameters behave as described in `section 7.1 of RFC 7692`_.
|
||||
|
||||
.. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1
|
||||
|
||||
Set them to :obj:`True` to include them in the negotiation offer without a
|
||||
value or to an integer value to include them with this value.
|
||||
|
||||
Args:
|
||||
server_no_context_takeover: prevent server from using context takeover.
|
||||
client_no_context_takeover: prevent client from using context takeover.
|
||||
server_max_window_bits: maximum size of the server's LZ77 sliding window
|
||||
in bits, between 8 and 15.
|
||||
client_max_window_bits: maximum size of the client's LZ77 sliding window
|
||||
in bits, between 8 and 15, or :obj:`True` to indicate support without
|
||||
setting a limit.
|
||||
compress_settings: additional keyword arguments for :func:`zlib.compressobj`,
|
||||
excluding ``wbits``.
|
||||
|
||||
"""
|
||||
|
||||
name = ExtensionName("permessage-deflate")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_no_context_takeover: bool = False,
|
||||
client_no_context_takeover: bool = False,
|
||||
server_max_window_bits: Optional[int] = None,
|
||||
client_max_window_bits: Optional[Union[int, bool]] = True,
|
||||
compress_settings: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Configure the Per-Message Deflate extension factory.
|
||||
|
||||
"""
|
||||
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
|
||||
raise ValueError("server_max_window_bits must be between 8 and 15")
|
||||
if not (
|
||||
client_max_window_bits is None
|
||||
or client_max_window_bits is True
|
||||
or 8 <= client_max_window_bits <= 15
|
||||
):
|
||||
raise ValueError("client_max_window_bits must be between 8 and 15")
|
||||
if compress_settings is not None and "wbits" in compress_settings:
|
||||
raise ValueError(
|
||||
"compress_settings must not include wbits, "
|
||||
"set client_max_window_bits instead"
|
||||
)
|
||||
|
||||
self.server_no_context_takeover = server_no_context_takeover
|
||||
self.client_no_context_takeover = client_no_context_takeover
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
self.compress_settings = compress_settings
|
||||
|
||||
def get_request_params(self) -> List[ExtensionParameter]:
|
||||
"""
|
||||
Build request parameters.
|
||||
|
||||
"""
|
||||
return _build_parameters(
|
||||
self.server_no_context_takeover,
|
||||
self.client_no_context_takeover,
|
||||
self.server_max_window_bits,
|
||||
self.client_max_window_bits,
|
||||
)
|
||||
|
||||
def process_response_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> PerMessageDeflate:
|
||||
"""
|
||||
Process response parameters.
|
||||
|
||||
Return an extension instance.
|
||||
|
||||
"""
|
||||
if any(other.name == self.name for other in accepted_extensions):
|
||||
raise exceptions.NegotiationError(f"received duplicate {self.name}")
|
||||
|
||||
# Request parameters are available in instance variables.
|
||||
|
||||
# Load response parameters in local variables.
|
||||
(
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
) = _extract_parameters(params, is_server=False)
|
||||
|
||||
# After comparing the request and the response, the final
|
||||
# configuration must be available in the local variables.
|
||||
|
||||
# server_no_context_takeover
|
||||
#
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True
|
||||
# True False Error!
|
||||
# True True True
|
||||
|
||||
if self.server_no_context_takeover:
|
||||
if not server_no_context_takeover:
|
||||
raise exceptions.NegotiationError("expected server_no_context_takeover")
|
||||
|
||||
# client_no_context_takeover
|
||||
#
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True
|
||||
# True False True - must change value
|
||||
# True True True
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
if not client_no_context_takeover:
|
||||
client_no_context_takeover = True
|
||||
|
||||
# server_max_window_bits
|
||||
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None 8≤M≤15 M
|
||||
# 8≤N≤15 None Error!
|
||||
# 8≤N≤15 8≤M≤N M
|
||||
# 8≤N≤15 N<M≤15 Error!
|
||||
|
||||
if self.server_max_window_bits is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
if server_max_window_bits is None:
|
||||
raise exceptions.NegotiationError("expected server_max_window_bits")
|
||||
elif server_max_window_bits > self.server_max_window_bits:
|
||||
raise exceptions.NegotiationError("unsupported server_max_window_bits")
|
||||
|
||||
# client_max_window_bits
|
||||
|
||||
# Req. Resp. Result
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None 8≤M≤15 Error!
|
||||
# True None None
|
||||
# True 8≤M≤15 M
|
||||
# 8≤N≤15 None N - must change value
|
||||
# 8≤N≤15 8≤M≤N M
|
||||
# 8≤N≤15 N<M≤15 Error!
|
||||
|
||||
if self.client_max_window_bits is None:
|
||||
if client_max_window_bits is not None:
|
||||
raise exceptions.NegotiationError("unexpected client_max_window_bits")
|
||||
|
||||
elif self.client_max_window_bits is True:
|
||||
pass
|
||||
|
||||
else:
|
||||
if client_max_window_bits is None:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif client_max_window_bits > self.client_max_window_bits:
|
||||
raise exceptions.NegotiationError("unsupported client_max_window_bits")
|
||||
|
||||
return PerMessageDeflate(
|
||||
server_no_context_takeover, # remote_no_context_takeover
|
||||
client_no_context_takeover, # local_no_context_takeover
|
||||
server_max_window_bits or 15, # remote_max_window_bits
|
||||
client_max_window_bits or 15, # local_max_window_bits
|
||||
self.compress_settings,
|
||||
)
|
||||
|
||||
|
||||
def enable_client_permessage_deflate(
|
||||
extensions: Optional[Sequence[ClientExtensionFactory]],
|
||||
) -> Sequence[ClientExtensionFactory]:
|
||||
"""
|
||||
Enable Per-Message Deflate with default settings in client extensions.
|
||||
|
||||
If the extension is already present, perhaps with non-default settings,
|
||||
the configuration isn't changed.
|
||||
|
||||
"""
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
if not any(
|
||||
extension_factory.name == ClientPerMessageDeflateFactory.name
|
||||
for extension_factory in extensions
|
||||
):
|
||||
extensions = list(extensions) + [
|
||||
ClientPerMessageDeflateFactory(
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
return extensions
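# Hedged usage sketch: connect() already calls the helper above with default
# settings; to customize compression, pass a ClientPerMessageDeflateFactory
# (defined earlier in this module) explicitly. The URI is an example value.
custom_deflate = ClientPerMessageDeflateFactory(
    server_max_window_bits=11,
    client_max_window_bits=11,
    compress_settings={"memLevel": 4},
)

# async with websockets.connect("ws://localhost:8765", extensions=[custom_deflate]) as ws:
#     await ws.send("hello")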
|
||||
|
||||
|
||||
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
|
||||
"""
|
||||
Server-side extension factory for the Per-Message Deflate extension.
|
||||
|
||||
Parameters behave as described in `section 7.1 of RFC 7692`_.
|
||||
|
||||
.. _section 7.1 of RFC 7692: https://www.rfc-editor.org/rfc/rfc7692.html#section-7.1
|
||||
|
||||
Set them to :obj:`True` to include them in the negotiation offer without a
|
||||
value or to an integer value to include them with this value.
|
||||
|
||||
Args:
|
||||
server_no_context_takeover: prevent server from using context takeover.
|
||||
client_no_context_takeover: prevent client from using context takeover.
|
||||
server_max_window_bits: maximum size of the server's LZ77 sliding window
|
||||
in bits, between 8 and 15.
|
||||
client_max_window_bits: maximum size of the client's LZ77 sliding window
|
||||
in bits, between 8 and 15.
|
||||
compress_settings: additional keyword arguments for :func:`zlib.compressobj`,
|
||||
excluding ``wbits``.
|
||||
require_client_max_window_bits: do not enable compression at all if
|
||||
client doesn't advertise support for ``client_max_window_bits``;
|
||||
the default behavior is to enable compression without enforcing
|
||||
``client_max_window_bits``.
|
||||
|
||||
"""
|
||||
|
||||
name = ExtensionName("permessage-deflate")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_no_context_takeover: bool = False,
|
||||
client_no_context_takeover: bool = False,
|
||||
server_max_window_bits: Optional[int] = None,
|
||||
client_max_window_bits: Optional[int] = None,
|
||||
compress_settings: Optional[Dict[str, Any]] = None,
|
||||
require_client_max_window_bits: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Configure the Per-Message Deflate extension factory.
|
||||
|
||||
"""
|
||||
if not (server_max_window_bits is None or 8 <= server_max_window_bits <= 15):
|
||||
raise ValueError("server_max_window_bits must be between 8 and 15")
|
||||
if not (client_max_window_bits is None or 8 <= client_max_window_bits <= 15):
|
||||
raise ValueError("client_max_window_bits must be between 8 and 15")
|
||||
if compress_settings is not None and "wbits" in compress_settings:
|
||||
raise ValueError(
|
||||
"compress_settings must not include wbits, "
|
||||
"set server_max_window_bits instead"
|
||||
)
|
||||
if client_max_window_bits is None and require_client_max_window_bits:
|
||||
raise ValueError(
|
||||
"require_client_max_window_bits is enabled, "
|
||||
"but client_max_window_bits isn't configured"
|
||||
)
|
||||
|
||||
self.server_no_context_takeover = server_no_context_takeover
|
||||
self.client_no_context_takeover = client_no_context_takeover
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
self.compress_settings = compress_settings
|
||||
self.require_client_max_window_bits = require_client_max_window_bits
|
||||
|
||||
def process_request_params(
|
||||
self,
|
||||
params: Sequence[ExtensionParameter],
|
||||
accepted_extensions: Sequence[Extension],
|
||||
) -> Tuple[List[ExtensionParameter], PerMessageDeflate]:
|
||||
"""
|
||||
Process request parameters.
|
||||
|
||||
Return response params and an extension instance.
|
||||
|
||||
"""
|
||||
if any(other.name == self.name for other in accepted_extensions):
|
||||
raise exceptions.NegotiationError(f"skipped duplicate {self.name}")
|
||||
|
||||
# Load request parameters in local variables.
|
||||
(
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
) = _extract_parameters(params, is_server=True)
|
||||
|
||||
# Configuration parameters are available in instance variables.
|
||||
|
||||
# After comparing the request and the configuration, the response must
|
||||
# be available in the local variables.
|
||||
|
||||
# server_no_context_takeover
|
||||
#
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True
|
||||
# True False True - must change value to True
|
||||
# True True True
|
||||
|
||||
if self.server_no_context_takeover:
|
||||
if not server_no_context_takeover:
|
||||
server_no_context_takeover = True
|
||||
|
||||
# client_no_context_takeover
|
||||
#
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# False False False
|
||||
# False True True (or False)
|
||||
# True False True - must change value to True
|
||||
# True True True (or False)
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
if not client_no_context_takeover:
|
||||
client_no_context_takeover = True
|
||||
|
||||
# server_max_window_bits
|
||||
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None 8≤M≤15 M
|
||||
# 8≤N≤15 None N - must change value
|
||||
# 8≤N≤15 8≤M≤N M
|
||||
# 8≤N≤15 N<M≤15 N - must change value
|
||||
|
||||
if self.server_max_window_bits is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
if server_max_window_bits is None:
|
||||
server_max_window_bits = self.server_max_window_bits
|
||||
elif server_max_window_bits > self.server_max_window_bits:
|
||||
server_max_window_bits = self.server_max_window_bits
|
||||
|
||||
# client_max_window_bits
|
||||
|
||||
# Config Req. Resp.
|
||||
# ------ ------ --------------------------------------------------
|
||||
# None None None
|
||||
# None True None - must change value
|
||||
# None 8≤M≤15 M (or None)
|
||||
# 8≤N≤15 None None or Error!
|
||||
# 8≤N≤15 True N - must change value
|
||||
# 8≤N≤15 8≤M≤N M (or None)
|
||||
# 8≤N≤15 N<M≤15 N
|
||||
|
||||
if self.client_max_window_bits is None:
|
||||
if client_max_window_bits is True:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
|
||||
else:
|
||||
if client_max_window_bits is None:
|
||||
if self.require_client_max_window_bits:
|
||||
raise exceptions.NegotiationError("required client_max_window_bits")
|
||||
elif client_max_window_bits is True:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif self.client_max_window_bits < client_max_window_bits:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
|
||||
return (
|
||||
_build_parameters(
|
||||
server_no_context_takeover,
|
||||
client_no_context_takeover,
|
||||
server_max_window_bits,
|
||||
client_max_window_bits,
|
||||
),
|
||||
PerMessageDeflate(
|
||||
client_no_context_takeover, # remote_no_context_takeover
|
||||
server_no_context_takeover, # local_no_context_takeover
|
||||
client_max_window_bits or 15, # remote_max_window_bits
|
||||
server_max_window_bits or 15, # local_max_window_bits
|
||||
self.compress_settings,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def enable_server_permessage_deflate(
|
||||
extensions: Optional[Sequence[ServerExtensionFactory]],
|
||||
) -> Sequence[ServerExtensionFactory]:
|
||||
"""
|
||||
Enable Per-Message Deflate with default settings in server extensions.
|
||||
|
||||
If the extension is already present, perhaps with non-default settings,
|
||||
the configuration isn't changed.
|
||||
|
||||
"""
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
if not any(
|
||||
ext_factory.name == ServerPerMessageDeflateFactory.name
|
||||
for ext_factory in extensions
|
||||
):
|
||||
extensions = list(extensions) + [
|
||||
ServerPerMessageDeflateFactory(
|
||||
server_max_window_bits=12,
|
||||
client_max_window_bits=12,
|
||||
compress_settings={"memLevel": 5},
|
||||
)
|
||||
]
|
||||
return extensions
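# Hedged usage sketch mirroring the client side: serve() applies the defaults
# above automatically; pass a ServerPerMessageDeflateFactory to override them.
# Host and port are example values.
custom_deflate = ServerPerMessageDeflateFactory(
    server_max_window_bits=11,
    client_max_window_bits=11,
    compress_settings={"memLevel": 4},
)

# websockets.serve(handler, "localhost", 8765, extensions=[custom_deflate])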
|
||||
445
venv/lib/python3.11/site-packages/websockets/frames.py
Normal file
@ -0,0 +1,445 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
import io
|
||||
import secrets
|
||||
import struct
|
||||
from typing import Callable, Generator, Optional, Sequence, Tuple
|
||||
|
||||
from . import exceptions, extensions
|
||||
from .typing import Data
|
||||
|
||||
|
||||
try:
|
||||
from .speedups import apply_mask
|
||||
except ImportError: # pragma: no cover
|
||||
from .utils import apply_mask
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Opcode",
|
||||
"OP_CONT",
|
||||
"OP_TEXT",
|
||||
"OP_BINARY",
|
||||
"OP_CLOSE",
|
||||
"OP_PING",
|
||||
"OP_PONG",
|
||||
"DATA_OPCODES",
|
||||
"CTRL_OPCODES",
|
||||
"Frame",
|
||||
"prepare_data",
|
||||
"prepare_ctrl",
|
||||
"Close",
|
||||
]
|
||||
|
||||
|
||||
class Opcode(enum.IntEnum):
|
||||
"""Opcode values for WebSocket frames."""
|
||||
|
||||
CONT, TEXT, BINARY = 0x00, 0x01, 0x02
|
||||
CLOSE, PING, PONG = 0x08, 0x09, 0x0A
|
||||
|
||||
|
||||
OP_CONT = Opcode.CONT
|
||||
OP_TEXT = Opcode.TEXT
|
||||
OP_BINARY = Opcode.BINARY
|
||||
OP_CLOSE = Opcode.CLOSE
|
||||
OP_PING = Opcode.PING
|
||||
OP_PONG = Opcode.PONG
|
||||
|
||||
DATA_OPCODES = OP_CONT, OP_TEXT, OP_BINARY
|
||||
CTRL_OPCODES = OP_CLOSE, OP_PING, OP_PONG
|
||||
|
||||
|
||||
# See https://www.iana.org/assignments/websocket/websocket.xhtml
|
||||
CLOSE_CODES = {
|
||||
1000: "OK",
|
||||
1001: "going away",
|
||||
1002: "protocol error",
|
||||
1003: "unsupported type",
|
||||
# 1004 is reserved
|
||||
1005: "no status code [internal]",
|
||||
1006: "connection closed abnormally [internal]",
|
||||
1007: "invalid data",
|
||||
1008: "policy violation",
|
||||
1009: "message too big",
|
||||
1010: "extension required",
|
||||
1011: "unexpected error",
|
||||
1012: "service restart",
|
||||
1013: "try again later",
|
||||
1014: "bad gateway",
|
||||
1015: "TLS failure [internal]",
|
||||
}
|
||||
|
||||
|
||||
# Close codes that are allowed in a close frame.
|
||||
# Using a set optimizes `code in EXTERNAL_CLOSE_CODES`.
|
||||
EXTERNAL_CLOSE_CODES = {
|
||||
1000,
|
||||
1001,
|
||||
1002,
|
||||
1003,
|
||||
1007,
|
||||
1008,
|
||||
1009,
|
||||
1010,
|
||||
1011,
|
||||
1012,
|
||||
1013,
|
||||
1014,
|
||||
}
|
||||
|
||||
OK_CLOSE_CODES = {1000, 1001}
|
||||
|
||||
|
||||
BytesLike = bytes, bytearray, memoryview
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Frame:
|
||||
"""
|
||||
WebSocket frame.
|
||||
|
||||
Attributes:
|
||||
opcode: Opcode.
|
||||
data: Payload data.
|
||||
fin: FIN bit.
|
||||
rsv1: RSV1 bit.
|
||||
rsv2: RSV2 bit.
|
||||
rsv3: RSV3 bit.
|
||||
|
||||
Only these fields are needed. The MASK bit, payload length and masking-key
|
||||
are handled on the fly when parsing and serializing frames.
|
||||
|
||||
"""
|
||||
|
||||
opcode: Opcode
|
||||
data: bytes
|
||||
fin: bool = True
|
||||
rsv1: bool = False
|
||||
rsv2: bool = False
|
||||
rsv3: bool = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a human-readable representation of a frame.
|
||||
|
||||
"""
|
||||
coding = None
|
||||
length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}"
|
||||
non_final = "" if self.fin else "continued"
|
||||
|
||||
if self.opcode is OP_TEXT:
|
||||
# Decoding only the beginning and the end is needlessly hard.
|
||||
# Decode the entire payload then elide later if necessary.
|
||||
data = repr(self.data.decode())
|
||||
elif self.opcode is OP_BINARY:
|
||||
# We'll show at most the first 16 bytes and the last 8 bytes.
|
||||
# Encode just what we need, plus two dummy bytes to elide later.
|
||||
binary = self.data
|
||||
if len(binary) > 25:
|
||||
binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]])
|
||||
data = " ".join(f"{byte:02x}" for byte in binary)
|
||||
elif self.opcode is OP_CLOSE:
|
||||
data = str(Close.parse(self.data))
|
||||
elif self.data:
|
||||
# We don't know if a Continuation frame contains text or binary.
|
||||
# Ping and Pong frames could contain UTF-8.
|
||||
# Attempt to decode as UTF-8 and display it as text; fallback to
|
||||
# binary. If self.data is a memoryview, it has no decode() method,
|
||||
# which raises AttributeError.
|
||||
try:
|
||||
data = repr(self.data.decode())
|
||||
coding = "text"
|
||||
except (UnicodeDecodeError, AttributeError):
|
||||
binary = self.data
|
||||
if len(binary) > 25:
|
||||
binary = b"".join([binary[:16], b"\x00\x00", binary[-8:]])
|
||||
data = " ".join(f"{byte:02x}" for byte in binary)
|
||||
coding = "binary"
|
||||
else:
|
||||
data = "''"
|
||||
|
||||
if len(data) > 75:
|
||||
data = data[:48] + "..." + data[-24:]
|
||||
|
||||
metadata = ", ".join(filter(None, [coding, length, non_final]))
|
||||
|
||||
return f"{self.opcode.name} {data} [{metadata}]"
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls,
|
||||
read_exact: Callable[[int], Generator[None, None, bytes]],
|
||||
*,
|
||||
mask: bool,
|
||||
max_size: Optional[int] = None,
|
||||
extensions: Optional[Sequence[extensions.Extension]] = None,
|
||||
) -> Generator[None, None, Frame]:
|
||||
"""
|
||||
Parse a WebSocket frame.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
Args:
|
||||
read_exact: generator-based coroutine that reads the requested
|
||||
bytes or raises an exception if there isn't enough data.
|
||||
mask: whether the frame should be masked, i.e. whether the read
|
||||
happens on the server side.
|
||||
max_size: maximum payload size in bytes.
|
||||
extensions: list of extensions, applied in reverse order.
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without a full WebSocket frame.
|
||||
UnicodeDecodeError: if the frame contains invalid UTF-8.
|
||||
PayloadTooBig: if the frame's payload size exceeds ``max_size``.
|
||||
ProtocolError: if the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
# Read the header.
|
||||
data = yield from read_exact(2)
|
||||
head1, head2 = struct.unpack("!BB", data)
|
||||
|
||||
# While not Pythonic, this is marginally faster than calling bool().
|
||||
fin = True if head1 & 0b10000000 else False
|
||||
rsv1 = True if head1 & 0b01000000 else False
|
||||
rsv2 = True if head1 & 0b00100000 else False
|
||||
rsv3 = True if head1 & 0b00010000 else False
|
||||
|
||||
try:
|
||||
opcode = Opcode(head1 & 0b00001111)
|
||||
except ValueError as exc:
|
||||
raise exceptions.ProtocolError("invalid opcode") from exc
|
||||
|
||||
if (True if head2 & 0b10000000 else False) != mask:
|
||||
raise exceptions.ProtocolError("incorrect masking")
|
||||
|
||||
length = head2 & 0b01111111
|
||||
if length == 126:
|
||||
data = yield from read_exact(2)
|
||||
(length,) = struct.unpack("!H", data)
|
||||
elif length == 127:
|
||||
data = yield from read_exact(8)
|
||||
(length,) = struct.unpack("!Q", data)
|
||||
if max_size is not None and length > max_size:
|
||||
raise exceptions.PayloadTooBig(
|
||||
f"over size limit ({length} > {max_size} bytes)"
|
||||
)
|
||||
if mask:
|
||||
mask_bytes = yield from read_exact(4)
|
||||
|
||||
# Read the data.
|
||||
data = yield from read_exact(length)
|
||||
if mask:
|
||||
data = apply_mask(data, mask_bytes)
|
||||
|
||||
frame = cls(opcode, data, fin, rsv1, rsv2, rsv3)
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in reversed(extensions):
|
||||
frame = extension.decode(frame, max_size=max_size)
|
||||
|
||||
frame.check()
|
||||
|
||||
return frame
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
*,
|
||||
mask: bool,
|
||||
extensions: Optional[Sequence[extensions.Extension]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Serialize a WebSocket frame.
|
||||
|
||||
Args:
|
||||
mask: whether the frame should be masked, i.e. whether the write
|
||||
happens on the client side.
|
||||
extensions: list of extensions, applied in order.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
self.check()
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in extensions:
|
||||
self = extension.encode(self)
|
||||
|
||||
output = io.BytesIO()
|
||||
|
||||
# Prepare the header.
|
||||
head1 = (
|
||||
(0b10000000 if self.fin else 0)
|
||||
| (0b01000000 if self.rsv1 else 0)
|
||||
| (0b00100000 if self.rsv2 else 0)
|
||||
| (0b00010000 if self.rsv3 else 0)
|
||||
| self.opcode
|
||||
)
|
||||
|
||||
head2 = 0b10000000 if mask else 0
|
||||
|
||||
length = len(self.data)
|
||||
if length < 126:
|
||||
output.write(struct.pack("!BB", head1, head2 | length))
|
||||
elif length < 65536:
|
||||
output.write(struct.pack("!BBH", head1, head2 | 126, length))
|
||||
else:
|
||||
output.write(struct.pack("!BBQ", head1, head2 | 127, length))
|
||||
|
||||
if mask:
|
||||
mask_bytes = secrets.token_bytes(4)
|
||||
output.write(mask_bytes)
|
||||
|
||||
# Prepare the data.
|
||||
if mask:
|
||||
data = apply_mask(self.data, mask_bytes)
|
||||
else:
|
||||
data = self.data
|
||||
output.write(data)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check that reserved bits and opcode have acceptable values.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if a reserved bit or the opcode is invalid.
|
||||
|
||||
"""
|
||||
if self.rsv1 or self.rsv2 or self.rsv3:
|
||||
raise exceptions.ProtocolError("reserved bits must be 0")
|
||||
|
||||
if self.opcode in CTRL_OPCODES:
|
||||
if len(self.data) > 125:
|
||||
raise exceptions.ProtocolError("control frame too long")
|
||||
if not self.fin:
|
||||
raise exceptions.ProtocolError("fragmented control frame")
|
||||
|
||||
|
||||
def prepare_data(data: Data) -> Tuple[int, bytes]:
|
||||
"""
|
||||
Convert a string or bytes-like object to an opcode and a bytes-like object.
|
||||
|
||||
This function is designed for data frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
|
||||
object encoding ``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
|
||||
object.
|
||||
|
||||
Raises:
|
||||
TypeError: if ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return OP_TEXT, data.encode("utf-8")
|
||||
elif isinstance(data, BytesLike):
|
||||
return OP_BINARY, data
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
def prepare_ctrl(data: Data) -> bytes:
|
||||
"""
|
||||
Convert a string or bytes-like object to bytes.
|
||||
|
||||
This function is designed for ping and pong frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
|
||||
``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return a :class:`bytes` object.
|
||||
|
||||
Raises:
|
||||
TypeError: if ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return data.encode("utf-8")
|
||||
elif isinstance(data, BytesLike):
|
||||
return bytes(data)
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Close:
|
||||
"""
|
||||
Code and reason for WebSocket close frames.
|
||||
|
||||
Attributes:
|
||||
code: Close code.
|
||||
reason: Close reason.
|
||||
|
||||
"""
|
||||
|
||||
code: int
|
||||
reason: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a human-readable representation of a close code and reason.
|
||||
|
||||
"""
|
||||
if 3000 <= self.code < 4000:
|
||||
explanation = "registered"
|
||||
elif 4000 <= self.code < 5000:
|
||||
explanation = "private use"
|
||||
else:
|
||||
explanation = CLOSE_CODES.get(self.code, "unknown")
|
||||
result = f"{self.code} ({explanation})"
|
||||
|
||||
if self.reason:
|
||||
result = f"{result} {self.reason}"
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data: bytes) -> Close:
|
||||
"""
|
||||
Parse the payload of a close frame.
|
||||
|
||||
Args:
|
||||
data: payload of the close frame.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if data is ill-formed.
|
||||
UnicodeDecodeError: if the reason isn't valid UTF-8.
|
||||
|
||||
"""
|
||||
if len(data) >= 2:
|
||||
(code,) = struct.unpack("!H", data[:2])
|
||||
reason = data[2:].decode("utf-8")
|
||||
close = cls(code, reason)
|
||||
close.check()
|
||||
return close
|
||||
elif len(data) == 0:
|
||||
return cls(1005, "")
|
||||
else:
|
||||
raise exceptions.ProtocolError("close frame too short")
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize the payload of a close frame.
|
||||
|
||||
"""
|
||||
self.check()
|
||||
return struct.pack("!H", self.code) + self.reason.encode("utf-8")
|
||||
|
||||
def check(self) -> None:
|
||||
"""
|
||||
Check that the close code has a valid value for a close frame.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if the close code is invalid.
|
||||
|
||||
"""
|
||||
if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
|
||||
raise exceptions.ProtocolError("invalid status code")
|
||||
587
venv/lib/python3.11/site-packages/websockets/headers.py
Normal file
@ -0,0 +1,587 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import ipaddress
|
||||
import re
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, cast
|
||||
|
||||
from . import exceptions
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
ExtensionName,
|
||||
ExtensionParameter,
|
||||
Subprotocol,
|
||||
UpgradeProtocol,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_host",
|
||||
"parse_connection",
|
||||
"parse_upgrade",
|
||||
"parse_extension",
|
||||
"build_extension",
|
||||
"parse_subprotocol",
|
||||
"build_subprotocol",
|
||||
"validate_subprotocols",
|
||||
"build_www_authenticate_basic",
|
||||
"parse_authorization_basic",
|
||||
"build_authorization_basic",
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def build_host(host: str, port: int, secure: bool) -> str:
|
||||
"""
|
||||
Build a ``Host`` header.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.2
|
||||
# IPv6 addresses must be enclosed in brackets.
|
||||
try:
|
||||
address = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
# host is a hostname
|
||||
pass
|
||||
else:
|
||||
# host is an IP address
|
||||
if address.version == 6:
|
||||
host = f"[{host}]"
|
||||
|
||||
if port != (443 if secure else 80):
|
||||
host = f"{host}:{port}"
|
||||
|
||||
return host
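# Hedged illustration of build_host: default ports are omitted and IPv6
# literals are bracketed, per RFC 3986.
assert build_host("example.com", 80, secure=False) == "example.com"
assert build_host("example.com", 8443, secure=True) == "example.com:8443"
assert build_host("::1", 80, secure=False) == "[::1]"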
|
||||
|
||||
|
||||
# To avoid a dependency on a parsing library, we implement manually the ABNF
|
||||
# described in https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 and
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B.
|
||||
|
||||
|
||||
def peek_ahead(header: str, pos: int) -> Optional[str]:
|
||||
"""
|
||||
Return the next character from ``header`` at the given position.
|
||||
|
||||
Return :obj:`None` at the end of ``header``.
|
||||
|
||||
We never need to peek more than one character ahead.
|
||||
|
||||
"""
|
||||
return None if pos == len(header) else header[pos]
|
||||
|
||||
|
||||
_OWS_re = re.compile(r"[\t ]*")
|
||||
|
||||
|
||||
def parse_OWS(header: str, pos: int) -> int:
|
||||
"""
|
||||
Parse optional whitespace from ``header`` at the given position.
|
||||
|
||||
Return the new position.
|
||||
|
||||
The whitespace itself isn't returned because it isn't significant.
|
||||
|
||||
"""
|
||||
# There's always a match, possibly empty, whose content doesn't matter.
|
||||
match = _OWS_re.match(header, pos)
|
||||
assert match is not None
|
||||
return match.end()
|
||||
|
||||
|
||||
_token_re = re.compile(r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
|
||||
def parse_token(header: str, pos: int, header_name: str) -> Tuple[str, int]:
|
||||
"""
|
||||
Parse a token from ``header`` at the given position.
|
||||
|
||||
Return the token value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
match = _token_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(header_name, "expected token", header, pos)
|
||||
return match.group(), match.end()
|
||||
|
||||
|
||||
_quoted_string_re = re.compile(
|
||||
r'"(?:[\x09\x20-\x21\x23-\x5b\x5d-\x7e]|\\[\x09\x20-\x7e\x80-\xff])*"'
|
||||
)
|
||||
|
||||
|
||||
_unquote_re = re.compile(r"\\([\x09\x20-\x7e\x80-\xff])")
|
||||
|
||||
|
||||
def parse_quoted_string(header: str, pos: int, header_name: str) -> Tuple[str, int]:
|
||||
"""
|
||||
Parse a quoted string from ``header`` at the given position.
|
||||
|
||||
Return the unquoted value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
match = _quoted_string_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected quoted string", header, pos
|
||||
)
|
||||
return _unquote_re.sub(r"\1", match.group()[1:-1]), match.end()
|
||||
|
||||
|
||||
_quotable_re = re.compile(r"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
_quote_re = re.compile(r"([\x22\x5c])")
|
||||
|
||||
|
||||
def build_quoted_string(value: str) -> str:
|
||||
"""
|
||||
Format ``value`` as a quoted string.
|
||||
|
||||
This is the reverse of :func:`parse_quoted_string`.
|
||||
|
||||
"""
|
||||
match = _quotable_re.fullmatch(value)
|
||||
if match is None:
|
||||
raise ValueError("invalid characters for quoted-string encoding")
|
||||
return '"' + _quote_re.sub(r"\\\1", value) + '"'
|
||||
|
||||
|
||||
def parse_list(
|
||||
parse_item: Callable[[str, int, str], Tuple[T, int]],
|
||||
header: str,
|
||||
pos: int,
|
||||
header_name: str,
|
||||
) -> List[T]:
|
||||
"""
|
||||
Parse a comma-separated list from ``header`` at the given position.
|
||||
|
||||
This is appropriate for parsing values with the following grammar:
|
||||
|
||||
1#item
|
||||
|
||||
``parse_item`` parses one item.
|
||||
|
||||
``header`` is assumed not to start or end with whitespace.
|
||||
|
||||
(This function is designed for parsing an entire header value and
|
||||
:func:`~websockets.http.read_headers` strips whitespace from values.)
|
||||
|
||||
Return a list of items.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
# Per https://www.rfc-editor.org/rfc/rfc7230.html#section-7, "a recipient
|
||||
# MUST parse and ignore a reasonable number of empty list elements";
|
||||
# hence while loops that remove extra delimiters.
|
||||
|
||||
# Remove extra delimiters before the first item.
|
||||
while peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
|
||||
items = []
|
||||
while True:
|
||||
# Loop invariant: an item starts at pos in header.
|
||||
item, pos = parse_item(header, pos, header_name)
|
||||
items.append(item)
|
||||
pos = parse_OWS(header, pos)
|
||||
|
||||
# We may have reached the end of the header.
|
||||
if pos == len(header):
|
||||
break
|
||||
|
||||
# There must be a delimiter after each element except the last one.
|
||||
if peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
else:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected comma", header, pos
|
||||
)
|
||||
|
||||
# Remove extra delimiters before the next item.
|
||||
while peek_ahead(header, pos) == ",":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
|
||||
# We may have reached the end of the header.
|
||||
if pos == len(header):
|
||||
break
|
||||
|
||||
# Since we only advance in the header by one character with peek_ahead()
|
||||
# or with the end position of a regex match, we can't overshoot the end.
|
||||
assert pos == len(header)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def parse_connection_option(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> Tuple[ConnectionOption, int]:
|
||||
"""
|
||||
Parse a Connection option from ``header`` at the given position.
|
||||
|
||||
Return the protocol value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
item, pos = parse_token(header, pos, header_name)
|
||||
return cast(ConnectionOption, item), pos
|
||||
|
||||
|
||||
def parse_connection(header: str) -> List[ConnectionOption]:
|
||||
"""
|
||||
Parse a ``Connection`` header.
|
||||
|
||||
Return a list of HTTP connection options.
|
||||
|
||||
    Args:
|
||||
header: value of the ``Connection`` header.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_connection_option, header, 0, "Connection")
|
||||
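# Illustrative sketch (not part of this module): parsing a Connection header
# that carries several comma-separated options.
from websockets.headers import parse_connection

assert parse_connection("keep-alive, Upgrade") == ["keep-alive", "Upgrade"]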
|
||||
|
||||
_protocol_re = re.compile(
|
||||
r"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+(?:/[-!#$%&\'*+.^_`|~0-9a-zA-Z]+)?"
|
||||
)
|
||||
|
||||
|
||||
def parse_upgrade_protocol(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> Tuple[UpgradeProtocol, int]:
|
||||
"""
|
||||
Parse an Upgrade protocol from ``header`` at the given position.
|
||||
|
||||
Return the protocol value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
match = _protocol_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected protocol", header, pos
|
||||
)
|
||||
return cast(UpgradeProtocol, match.group()), match.end()
|
||||
|
||||
|
||||
def parse_upgrade(header: str) -> List[UpgradeProtocol]:
|
||||
"""
|
||||
Parse an ``Upgrade`` header.
|
||||
|
||||
Return a list of HTTP protocols.
|
||||
|
||||
Args:
|
||||
header: value of the ``Upgrade`` header.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_upgrade_protocol, header, 0, "Upgrade")
|
||||
|
||||
|
||||
def parse_extension_item_param(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> Tuple[ExtensionParameter, int]:
|
||||
"""
|
||||
Parse a single extension parameter from ``header`` at the given position.
|
||||
|
||||
Return a ``(name, value)`` pair and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
# Extract parameter name.
|
||||
name, pos = parse_token(header, pos, header_name)
|
||||
pos = parse_OWS(header, pos)
|
||||
# Extract parameter value, if there is one.
|
||||
value: Optional[str] = None
|
||||
if peek_ahead(header, pos) == "=":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
if peek_ahead(header, pos) == '"':
|
||||
pos_before = pos # for proper error reporting below
|
||||
value, pos = parse_quoted_string(header, pos, header_name)
|
||||
# https://www.rfc-editor.org/rfc/rfc6455.html#section-9.1 says:
|
||||
# the value after quoted-string unescaping MUST conform to
|
||||
# the 'token' ABNF.
|
||||
if _token_re.fullmatch(value) is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "invalid quoted header content", header, pos_before
|
||||
)
|
||||
else:
|
||||
value, pos = parse_token(header, pos, header_name)
|
||||
pos = parse_OWS(header, pos)
|
||||
|
||||
return (name, value), pos
|
||||
|
||||
|
||||
def parse_extension_item(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> Tuple[ExtensionHeader, int]:
|
||||
"""
|
||||
Parse an extension definition from ``header`` at the given position.
|
||||
|
||||
Return an ``(extension name, parameters)`` pair, where ``parameters`` is a
|
||||
list of ``(name, value)`` pairs, and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
# Extract extension name.
|
||||
name, pos = parse_token(header, pos, header_name)
|
||||
pos = parse_OWS(header, pos)
|
||||
# Extract all parameters.
|
||||
parameters = []
|
||||
while peek_ahead(header, pos) == ";":
|
||||
pos = parse_OWS(header, pos + 1)
|
||||
parameter, pos = parse_extension_item_param(header, pos, header_name)
|
||||
parameters.append(parameter)
|
||||
return (cast(ExtensionName, name), parameters), pos
|
||||
|
||||
|
||||
def parse_extension(header: str) -> List[ExtensionHeader]:
|
||||
"""
|
||||
Parse a ``Sec-WebSocket-Extensions`` header.
|
||||
|
||||
Return a list of WebSocket extensions and their parameters in this format::
|
||||
|
||||
[
|
||||
(
|
||||
'extension name',
|
||||
[
|
||||
('parameter name', 'parameter value'),
|
||||
....
|
||||
]
|
||||
),
|
||||
...
|
||||
]
|
||||
|
||||
Parameter values are :obj:`None` when no value is provided.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_extension_item, header, 0, "Sec-WebSocket-Extensions")
|
||||
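# Illustrative sketch (not part of this module): parsing a typical
# permessage-deflate offer; parameters given without a value map to None.
from websockets.headers import parse_extension

header = "permessage-deflate; client_max_window_bits; server_max_window_bits=10"
assert parse_extension(header) == [
    (
        "permessage-deflate",
        [("client_max_window_bits", None), ("server_max_window_bits", "10")],
    ),
]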
|
||||
|
||||
parse_extension_list = parse_extension # alias for backwards compatibility
|
||||
|
||||
|
||||
def build_extension_item(
|
||||
name: ExtensionName, parameters: List[ExtensionParameter]
|
||||
) -> str:
|
||||
"""
|
||||
Build an extension definition.
|
||||
|
||||
This is the reverse of :func:`parse_extension_item`.
|
||||
|
||||
"""
|
||||
return "; ".join(
|
||||
[cast(str, name)]
|
||||
+ [
|
||||
# Quoted strings aren't necessary because values are always tokens.
|
||||
name if value is None else f"{name}={value}"
|
||||
for name, value in parameters
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_extension(extensions: Sequence[ExtensionHeader]) -> str:
|
||||
"""
|
||||
Build a ``Sec-WebSocket-Extensions`` header.
|
||||
|
||||
This is the reverse of :func:`parse_extension`.
|
||||
|
||||
"""
|
||||
return ", ".join(
|
||||
build_extension_item(name, parameters) for name, parameters in extensions
|
||||
)
|
||||
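# Illustrative sketch (not part of this module): build_extension() is the
# reverse operation, serializing extension offers back into a header value.
from websockets.headers import build_extension

assert build_extension(
    [("permessage-deflate", [("server_max_window_bits", "10")])]
) == "permessage-deflate; server_max_window_bits=10"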
|
||||
|
||||
build_extension_list = build_extension # alias for backwards compatibility
|
||||
|
||||
|
||||
def parse_subprotocol_item(
|
||||
header: str, pos: int, header_name: str
|
||||
) -> Tuple[Subprotocol, int]:
|
||||
"""
|
||||
Parse a subprotocol from ``header`` at the given position.
|
||||
|
||||
Return the subprotocol value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
item, pos = parse_token(header, pos, header_name)
|
||||
return cast(Subprotocol, item), pos
|
||||
|
||||
|
||||
def parse_subprotocol(header: str) -> List[Subprotocol]:
|
||||
"""
|
||||
Parse a ``Sec-WebSocket-Protocol`` header.
|
||||
|
||||
Return a list of WebSocket subprotocols.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
return parse_list(parse_subprotocol_item, header, 0, "Sec-WebSocket-Protocol")
|
||||
|
||||
|
||||
parse_subprotocol_list = parse_subprotocol # alias for backwards compatibility
|
||||
|
||||
|
||||
def build_subprotocol(subprotocols: Sequence[Subprotocol]) -> str:
|
||||
"""
|
||||
Build a ``Sec-WebSocket-Protocol`` header.
|
||||
|
||||
This is the reverse of :func:`parse_subprotocol`.
|
||||
|
||||
"""
|
||||
return ", ".join(subprotocols)
|
||||
|
||||
|
||||
build_subprotocol_list = build_subprotocol # alias for backwards compatibility
|
||||
|
||||
|
||||
def validate_subprotocols(subprotocols: Sequence[Subprotocol]) -> None:
|
||||
"""
|
||||
Validate that ``subprotocols`` is suitable for :func:`build_subprotocol`.
|
||||
|
||||
"""
|
||||
if not isinstance(subprotocols, Sequence):
|
||||
raise TypeError("subprotocols must be a list")
|
||||
if isinstance(subprotocols, str):
|
||||
raise TypeError("subprotocols must be a list, not a str")
|
||||
for subprotocol in subprotocols:
|
||||
if not _token_re.fullmatch(subprotocol):
|
||||
raise ValueError(f"invalid subprotocol: {subprotocol}")
|
||||
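# Illustrative sketch (not part of this module): subprotocols must be passed as
# a sequence of token strings, which validate_subprotocols() accepts silently.
from websockets.headers import build_subprotocol, validate_subprotocols

subprotocols = ["chat", "superchat"]
validate_subprotocols(subprotocols)
assert build_subprotocol(subprotocols) == "chat, superchat"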
|
||||
|
||||
def build_www_authenticate_basic(realm: str) -> str:
|
||||
"""
|
||||
Build a ``WWW-Authenticate`` header for HTTP Basic Auth.
|
||||
|
||||
Args:
|
||||
realm: identifier of the protection space.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7617.html#section-2
|
||||
realm = build_quoted_string(realm)
|
||||
charset = build_quoted_string("UTF-8")
|
||||
return f"Basic realm={realm}, charset={charset}"
|
||||
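# Illustrative sketch (not part of this module): the challenge a server sends
# in a 401 response when HTTP Basic Auth is required; the realm is an example.
from websockets.headers import build_www_authenticate_basic

assert (
    build_www_authenticate_basic("my dev server")
    == 'Basic realm="my dev server", charset="UTF-8"'
)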
|
||||
|
||||
_token68_re = re.compile(r"[A-Za-z0-9-._~+/]+=*")
|
||||
|
||||
|
||||
def parse_token68(header: str, pos: int, header_name: str) -> Tuple[str, int]:
|
||||
"""
|
||||
Parse a token68 from ``header`` at the given position.
|
||||
|
||||
Return the token value and the new position.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
|
||||
"""
|
||||
match = _token68_re.match(header, pos)
|
||||
if match is None:
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
header_name, "expected token68", header, pos
|
||||
)
|
||||
return match.group(), match.end()
|
||||
|
||||
|
||||
def parse_end(header: str, pos: int, header_name: str) -> None:
|
||||
"""
|
||||
Check that parsing reached the end of header.
|
||||
|
||||
"""
|
||||
if pos < len(header):
|
||||
raise exceptions.InvalidHeaderFormat(header_name, "trailing data", header, pos)
|
||||
|
||||
|
||||
def parse_authorization_basic(header: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Parse an ``Authorization`` header for HTTP Basic Auth.
|
||||
|
||||
Return a ``(username, password)`` tuple.
|
||||
|
||||
Args:
|
||||
header: value of the ``Authorization`` header.
|
||||
|
||||
Raises:
|
||||
InvalidHeaderFormat: on invalid inputs.
|
||||
InvalidHeaderValue: on unsupported inputs.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7235.html#section-2.1
|
||||
# https://www.rfc-editor.org/rfc/rfc7617.html#section-2
|
||||
scheme, pos = parse_token(header, 0, "Authorization")
|
||||
if scheme.lower() != "basic":
|
||||
raise exceptions.InvalidHeaderValue(
|
||||
"Authorization",
|
||||
f"unsupported scheme: {scheme}",
|
||||
)
|
||||
if peek_ahead(header, pos) != " ":
|
||||
raise exceptions.InvalidHeaderFormat(
|
||||
"Authorization", "expected space after scheme", header, pos
|
||||
)
|
||||
pos += 1
|
||||
basic_credentials, pos = parse_token68(header, pos, "Authorization")
|
||||
parse_end(header, pos, "Authorization")
|
||||
|
||||
try:
|
||||
user_pass = base64.b64decode(basic_credentials.encode()).decode()
|
||||
except binascii.Error:
|
||||
raise exceptions.InvalidHeaderValue(
|
||||
"Authorization",
|
||||
"expected base64-encoded credentials",
|
||||
) from None
|
||||
try:
|
||||
username, password = user_pass.split(":", 1)
|
||||
except ValueError:
|
||||
raise exceptions.InvalidHeaderValue(
|
||||
"Authorization",
|
||||
"expected username:password credentials",
|
||||
) from None
|
||||
|
||||
return username, password
|
||||
|
||||
|
||||
def build_authorization_basic(username: str, password: str) -> str:
|
||||
"""
|
||||
Build an ``Authorization`` header for HTTP Basic Auth.
|
||||
|
||||
This is the reverse of :func:`parse_authorization_basic`.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7617.html#section-2
|
||||
assert ":" not in username
|
||||
user_pass = f"{username}:{password}"
|
||||
basic_credentials = base64.b64encode(user_pass.encode()).decode()
|
||||
return "Basic " + basic_credentials
|
||||
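# Illustrative sketch (not part of this module): round-tripping HTTP Basic Auth
# credentials; the username/password pair is an example value.
from websockets.headers import build_authorization_basic, parse_authorization_basic

header = build_authorization_basic("hello", "iloveyou")  # "Basic aGVsbG86aWxvdmV5b3U="
assert parse_authorization_basic(header) == ("hello", "iloveyou")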
30
venv/lib/python3.11/site-packages/websockets/http.py
Normal file
@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from .imports import lazy_import
|
||||
from .version import version as websockets_version
|
||||
|
||||
|
||||
# For backwards compatibility:
|
||||
|
||||
|
||||
lazy_import(
|
||||
globals(),
|
||||
# Headers and MultipleValuesError used to be defined in this module.
|
||||
aliases={
|
||||
"Headers": ".datastructures",
|
||||
"MultipleValuesError": ".datastructures",
|
||||
},
|
||||
deprecated_aliases={
|
||||
"read_request": ".legacy.http",
|
||||
"read_response": ".legacy.http",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["USER_AGENT"]
|
||||
|
||||
|
||||
PYTHON_VERSION = "{}.{}".format(*sys.version_info)
|
||||
USER_AGENT = f"Python/{PYTHON_VERSION} websockets/{websockets_version}"
|
||||
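# Illustrative sketch (not part of this module): the resulting header value
# depends on the interpreter and on the installed websockets version.
from websockets.http import USER_AGENT

print(USER_AGENT)  # e.g. "Python/3.11 websockets/10.4"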
364
venv/lib/python3.11/site-packages/websockets/http11.py
Normal file
@ -0,0 +1,364 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
import warnings
|
||||
from typing import Callable, Generator, Optional
|
||||
|
||||
from . import datastructures, exceptions
|
||||
|
||||
|
||||
# Maximum total size of headers is around 128 * 8 KiB = 1 MiB.
|
||||
MAX_HEADERS = 128
|
||||
|
||||
# Limit request line and header lines. 8KiB is the most common default
|
||||
# configuration of popular HTTP servers.
|
||||
MAX_LINE = 8192
|
||||
|
||||
# Support for HTTP response bodies is intended to read an error message
|
||||
# returned by a server. It isn't designed to perform large file transfers.
|
||||
MAX_BODY = 2**20 # 1 MiB
|
||||
|
||||
|
||||
def d(value: bytes) -> str:
|
||||
"""
|
||||
Decode a bytestring for interpolating into an error message.
|
||||
|
||||
"""
|
||||
return value.decode(errors="backslashreplace")
|
||||
|
||||
|
||||
# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B.
|
||||
|
||||
# Regex for validating header names.
|
||||
|
||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
# Regex for validating header values.
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
|
||||
|
||||
# The ABNF is complicated because it attempts to express that optional
|
||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
|
||||
|
||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
|
||||
|
||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Request:
|
||||
"""
|
||||
WebSocket handshake request.
|
||||
|
||||
Attributes:
|
||||
path: Request path, including optional query.
|
||||
headers: Request headers.
|
||||
"""
|
||||
|
||||
path: str
|
||||
headers: datastructures.Headers
|
||||
    # body isn't useful in the context of this library.
|
||||
|
||||
_exception: Optional[Exception] = None
|
||||
|
||||
@property
|
||||
def exception(self) -> Optional[Exception]: # pragma: no cover
|
||||
warnings.warn(
|
||||
"Request.exception is deprecated; "
|
||||
"use ServerConnection.handshake_exc instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self._exception
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls,
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, Request]:
|
||||
"""
|
||||
Parse a WebSocket handshake request.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
The request path isn't URL-decoded or validated in any way.
|
||||
|
||||
The request path and headers are expected to contain only ASCII
|
||||
characters. Other characters are represented with surrogate escapes.
|
||||
|
||||
:meth:`parse` doesn't attempt to read the request body because
|
||||
WebSocket handshake requests don't have one. If the request contains a
|
||||
body, it may be read from the data stream after :meth:`parse` returns.
|
||||
|
||||
Args:
|
||||
read_line: generator-based coroutine that reads a LF-terminated
|
||||
                line or raises an exception if there isn't enough data.
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without a full HTTP request.
|
||||
SecurityError: if the request exceeds a security limit.
|
||||
ValueError: if the request isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1
|
||||
|
||||
# Parsing is simple because fixed values are expected for method and
|
||||
# version and because path isn't checked. Since WebSocket software tends
|
||||
# to implement HTTP/1.1 strictly, there's little need for lenient parsing.
|
||||
|
||||
try:
|
||||
request_line = yield from parse_line(read_line)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP request line") from exc
|
||||
|
||||
try:
|
||||
method, raw_path, version = request_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
|
||||
|
||||
if method != b"GET":
|
||||
raise ValueError(f"unsupported HTTP method: {d(method)}")
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
path = raw_path.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = yield from parse_headers(read_line)
|
||||
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3
|
||||
|
||||
if "Transfer-Encoding" in headers:
|
||||
raise NotImplementedError("transfer codings aren't supported")
|
||||
|
||||
if "Content-Length" in headers:
|
||||
raise ValueError("unsupported request body")
|
||||
|
||||
return cls(path, headers)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize a WebSocket handshake request.
|
||||
|
||||
"""
|
||||
# Since the request line and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
request = f"GET {self.path} HTTP/1.1\r\n".encode()
|
||||
request += self.headers.serialize()
|
||||
return request
|
||||
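# Illustrative sketch (not part of this module): driving the generator-based
# Request.parse() over an in-memory buffer. The helper names make_read_line()
# and run() are ad-hoc, defined here only for the example.
import io

from websockets.http11 import Request


def make_read_line(stream: io.BytesIO):
    def read_line(limit: int):
        # Generator-based coroutine; it never yields because data is already buffered.
        return stream.readline(limit)
        yield  # unreachable, but makes this function a generator
    return read_line


def run(gen):
    # Drive a generator-based coroutine to completion and return its value.
    try:
        while True:
            next(gen)
    except StopIteration as exc:
        return exc.value


raw = b"GET /chat HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\n\r\n"
request = run(Request.parse(make_read_line(io.BytesIO(raw))))
assert request.path == "/chat"
assert request.serialize() == raw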
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Response:
|
||||
"""
|
||||
WebSocket handshake response.
|
||||
|
||||
Attributes:
|
||||
status_code: Response code.
|
||||
reason_phrase: Response reason.
|
||||
headers: Response headers.
|
||||
body: Response body, if any.
|
||||
|
||||
"""
|
||||
|
||||
status_code: int
|
||||
reason_phrase: str
|
||||
headers: datastructures.Headers
|
||||
body: Optional[bytes] = None
|
||||
|
||||
_exception: Optional[Exception] = None
|
||||
|
||||
@property
|
||||
def exception(self) -> Optional[Exception]: # pragma: no cover
|
||||
warnings.warn(
|
||||
"Response.exception is deprecated; "
|
||||
"use ClientConnection.handshake_exc instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return self._exception
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls,
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
read_exact: Callable[[int], Generator[None, None, bytes]],
|
||||
read_to_eof: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, Response]:
|
||||
"""
|
||||
Parse a WebSocket handshake response.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
The reason phrase and headers are expected to contain only ASCII
|
||||
characters. Other characters are represented with surrogate escapes.
|
||||
|
||||
Args:
|
||||
read_line: generator-based coroutine that reads a LF-terminated
|
||||
line or raises an exception if there isn't enough data.
|
||||
read_exact: generator-based coroutine that reads the requested
|
||||
bytes or raises an exception if there isn't enough data.
|
||||
read_to_eof: generator-based coroutine that reads until the end
|
||||
of the stream.
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without a full HTTP response.
|
||||
SecurityError: if the response exceeds a security limit.
|
||||
LookupError: if the response isn't well formatted.
|
||||
ValueError: if the response isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2
|
||||
|
||||
try:
|
||||
status_line = yield from parse_line(read_line)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP status line") from exc
|
||||
|
||||
try:
|
||||
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
|
||||
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
try:
|
||||
status_code = int(raw_status_code)
|
||||
except ValueError: # invalid literal for int() with base 10
|
||||
raise ValueError(
|
||||
f"invalid HTTP status code: {d(raw_status_code)}"
|
||||
) from None
|
||||
if not 100 <= status_code < 1000:
|
||||
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
|
||||
if not _value_re.fullmatch(raw_reason):
|
||||
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
|
||||
reason = raw_reason.decode()
|
||||
|
||||
headers = yield from parse_headers(read_line)
|
||||
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.3.3
|
||||
|
||||
if "Transfer-Encoding" in headers:
|
||||
raise NotImplementedError("transfer codings aren't supported")
|
||||
|
||||
# Since websockets only does GET requests (no HEAD, no CONNECT), all
|
||||
# responses except 1xx, 204, and 304 include a message body.
|
||||
if 100 <= status_code < 200 or status_code == 204 or status_code == 304:
|
||||
body = None
|
||||
else:
|
||||
content_length: Optional[int]
|
||||
try:
|
||||
# MultipleValuesError is sufficiently unlikely that we don't
|
||||
# attempt to handle it. Instead we document that its parent
|
||||
# class, LookupError, may be raised.
|
||||
raw_content_length = headers["Content-Length"]
|
||||
except KeyError:
|
||||
content_length = None
|
||||
else:
|
||||
content_length = int(raw_content_length)
|
||||
|
||||
if content_length is None:
|
||||
try:
|
||||
body = yield from read_to_eof(MAX_BODY)
|
||||
except RuntimeError:
|
||||
raise exceptions.SecurityError(
|
||||
f"body too large: over {MAX_BODY} bytes"
|
||||
)
|
||||
elif content_length > MAX_BODY:
|
||||
raise exceptions.SecurityError(
|
||||
f"body too large: {content_length} bytes"
|
||||
)
|
||||
else:
|
||||
body = yield from read_exact(content_length)
|
||||
|
||||
return cls(status_code, reason, headers, body)
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""
|
||||
Serialize a WebSocket handshake response.
|
||||
|
||||
"""
|
||||
# Since the status line and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
response = f"HTTP/1.1 {self.status_code} {self.reason_phrase}\r\n".encode()
|
||||
response += self.headers.serialize()
|
||||
if self.body is not None:
|
||||
response += self.body
|
||||
return response
|
||||
|
||||
|
||||
def parse_headers(
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, datastructures.Headers]:
|
||||
"""
|
||||
Parse HTTP headers.
|
||||
|
||||
Non-ASCII characters are represented with surrogate escapes.
|
||||
|
||||
Args:
|
||||
read_line: generator-based coroutine that reads a LF-terminated line
|
||||
or raises an exception if there isn't enough data.
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without complete headers.
|
||||
SecurityError: if the request exceeds a security limit.
|
||||
ValueError: if the request isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
headers = datastructures.Headers()
|
||||
for _ in range(MAX_HEADERS + 1):
|
||||
try:
|
||||
line = yield from parse_line(read_line)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP headers") from exc
|
||||
if line == b"":
|
||||
break
|
||||
|
||||
try:
|
||||
raw_name, raw_value = line.split(b":", 1)
|
||||
except ValueError: # not enough values to unpack (expected 2, got 1)
|
||||
raise ValueError(f"invalid HTTP header line: {d(line)}") from None
|
||||
if not _token_re.fullmatch(raw_name):
|
||||
raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
|
||||
raw_value = raw_value.strip(b" \t")
|
||||
if not _value_re.fullmatch(raw_value):
|
||||
raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
|
||||
|
||||
name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
|
||||
value = raw_value.decode("ascii", "surrogateescape")
|
||||
headers[name] = value
|
||||
|
||||
else:
|
||||
raise exceptions.SecurityError("too many HTTP headers")
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def parse_line(
|
||||
read_line: Callable[[int], Generator[None, None, bytes]],
|
||||
) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Parse a single line.
|
||||
|
||||
CRLF is stripped from the return value.
|
||||
|
||||
Args:
|
||||
read_line: generator-based coroutine that reads a LF-terminated line
|
||||
or raises an exception if there isn't enough data.
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without a CRLF.
|
||||
SecurityError: if the response exceeds a security limit.
|
||||
|
||||
"""
|
||||
try:
|
||||
line = yield from read_line(MAX_LINE)
|
||||
except RuntimeError:
|
||||
raise exceptions.SecurityError("line too long")
|
||||
# Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5
|
||||
if not line.endswith(b"\r\n"):
|
||||
raise EOFError("line without CRLF")
|
||||
return line[:-2]
|
||||
99
venv/lib/python3.11/site-packages/websockets/imports.py
Normal file
@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
|
||||
__all__ = ["lazy_import"]
|
||||
|
||||
|
||||
def import_name(name: str, source: str, namespace: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Import ``name`` from ``source`` in ``namespace``.
|
||||
|
||||
There are two use cases:
|
||||
|
||||
- ``name`` is an object defined in ``source``;
|
||||
- ``name`` is a submodule of ``source``.
|
||||
|
||||
Neither :func:`__import__` nor :func:`~importlib.import_module` does
|
||||
exactly this. :func:`__import__` is closer to the intended behavior.
|
||||
|
||||
"""
|
||||
level = 0
|
||||
while source[level] == ".":
|
||||
level += 1
|
||||
assert level < len(source), "importing from parent isn't supported"
|
||||
module = __import__(source[level:], namespace, None, [name], level)
|
||||
return getattr(module, name)
|
||||
|
||||
|
||||
def lazy_import(
|
||||
namespace: Dict[str, Any],
|
||||
aliases: Optional[Dict[str, str]] = None,
|
||||
deprecated_aliases: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Provide lazy, module-level imports.
|
||||
|
||||
Typical use::
|
||||
|
||||
__getattr__, __dir__ = lazy_import(
|
||||
globals(),
|
||||
aliases={
|
||||
"<name>": "<source module>",
|
||||
...
|
||||
},
|
||||
deprecated_aliases={
|
||||
...,
|
||||
}
|
||||
)
|
||||
|
||||
This function defines ``__getattr__`` and ``__dir__`` per :pep:`562`.
|
||||
|
||||
"""
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
if deprecated_aliases is None:
|
||||
deprecated_aliases = {}
|
||||
|
||||
namespace_set = set(namespace)
|
||||
aliases_set = set(aliases)
|
||||
deprecated_aliases_set = set(deprecated_aliases)
|
||||
|
||||
assert not namespace_set & aliases_set, "namespace conflict"
|
||||
assert not namespace_set & deprecated_aliases_set, "namespace conflict"
|
||||
assert not aliases_set & deprecated_aliases_set, "namespace conflict"
|
||||
|
||||
package = namespace["__name__"]
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
assert aliases is not None # mypy cannot figure this out
|
||||
try:
|
||||
source = aliases[name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return import_name(name, source, namespace)
|
||||
|
||||
assert deprecated_aliases is not None # mypy cannot figure this out
|
||||
try:
|
||||
source = deprecated_aliases[name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{package}.{name} is deprecated",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return import_name(name, source, namespace)
|
||||
|
||||
raise AttributeError(f"module {package!r} has no attribute {name!r}")
|
||||
|
||||
namespace["__getattr__"] = __getattr__
|
||||
|
||||
def __dir__() -> Iterable[str]:
|
||||
return sorted(namespace_set | aliases_set | deprecated_aliases_set)
|
||||
|
||||
namespace["__dir__"] = __dir__
|
||||
Binary file not shown.
188
venv/lib/python3.11/site-packages/websockets/legacy/auth.py
Normal file
@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hmac
|
||||
import http
|
||||
from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast
|
||||
|
||||
from ..datastructures import Headers
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..headers import build_www_authenticate_basic, parse_authorization_basic
|
||||
from .server import HTTPResponse, WebSocketServerProtocol
|
||||
|
||||
|
||||
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
|
||||
|
||||
Credentials = Tuple[str, str]
|
||||
|
||||
|
||||
def is_credentials(value: Any) -> bool:
|
||||
try:
|
||||
username, password = value
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
|
||||
"""
|
||||
WebSocket server protocol that enforces HTTP Basic Auth.
|
||||
|
||||
"""
|
||||
|
||||
realm: str = ""
|
||||
"""
|
||||
Scope of protection.
|
||||
|
||||
If provided, it should contain only ASCII characters because the
|
||||
encoding of non-ASCII characters is undefined.
|
||||
"""
|
||||
|
||||
username: Optional[str] = None
|
||||
"""Username of the authenticated user."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
realm: Optional[str] = None,
|
||||
check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if realm is not None:
|
||||
self.realm = realm # shadow class attribute
|
||||
self._check_credentials = check_credentials
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def check_credentials(self, username: str, password: str) -> bool:
|
||||
"""
|
||||
Check whether credentials are authorized.
|
||||
|
||||
This coroutine may be overridden in a subclass, for example to
|
||||
authenticate against a database or an external service.
|
||||
|
||||
Args:
|
||||
username: HTTP Basic Auth username.
|
||||
password: HTTP Basic Auth password.
|
||||
|
||||
Returns:
|
||||
bool: :obj:`True` if the handshake should continue;
|
||||
            :obj:`False` if it should fail with an HTTP 401 error.
|
||||
|
||||
"""
|
||||
if self._check_credentials is not None:
|
||||
return await self._check_credentials(username, password)
|
||||
|
||||
return False
|
||||
|
||||
async def process_request(
|
||||
self,
|
||||
path: str,
|
||||
request_headers: Headers,
|
||||
) -> Optional[HTTPResponse]:
|
||||
"""
|
||||
        Check HTTP Basic Auth and return an HTTP 401 response if needed.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request_headers["Authorization"]
|
||||
except KeyError:
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Missing credentials\n",
|
||||
)
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Unsupported credentials\n",
|
||||
)
|
||||
|
||||
if not await self.check_credentials(username, password):
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Invalid credentials\n",
|
||||
)
|
||||
|
||||
self.username = username
|
||||
|
||||
return await super().process_request(path, request_headers)
|
||||
|
||||
|
||||
def basic_auth_protocol_factory(
|
||||
realm: Optional[str] = None,
|
||||
credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None,
|
||||
check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None,
|
||||
create_protocol: Optional[Callable[[Any], BasicAuthWebSocketServerProtocol]] = None,
|
||||
) -> Callable[[Any], BasicAuthWebSocketServerProtocol]:
|
||||
"""
|
||||
Protocol factory that enforces HTTP Basic Auth.
|
||||
|
||||
:func:`basic_auth_protocol_factory` is designed to integrate with
|
||||
:func:`~websockets.server.serve` like this::
|
||||
|
||||
websockets.serve(
|
||||
...,
|
||||
create_protocol=websockets.basic_auth_protocol_factory(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
)
|
||||
)
|
||||
|
||||
Args:
|
||||
realm: indicates the scope of protection. It should contain only ASCII
|
||||
characters because the encoding of non-ASCII characters is
|
||||
undefined. Refer to section 2.2 of :rfc:`7235` for details.
|
||||
credentials: defines hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: defines a coroutine that verifies credentials.
|
||||
This coroutine receives ``username`` and ``password`` arguments
|
||||
and returns a :class:`bool`. One of ``credentials`` or
|
||||
``check_credentials`` must be provided but not both.
|
||||
create_protocol: factory that creates the protocol. By default, this
|
||||
is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
|
||||
by a subclass.
|
||||
Raises:
|
||||
TypeError: if the ``credentials`` or ``check_credentials`` argument is
|
||||
wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Credentials, credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(credentials)
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
async def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
if create_protocol is None:
|
||||
# Not sure why mypy cannot figure this out.
|
||||
create_protocol = cast(
|
||||
Callable[[Any], BasicAuthWebSocketServerProtocol],
|
||||
BasicAuthWebSocketServerProtocol,
|
||||
)
|
||||
|
||||
return functools.partial(
|
||||
create_protocol,
|
||||
realm=realm,
|
||||
check_credentials=check_credentials,
|
||||
)
|
||||
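# Illustrative sketch (not part of this module): authenticating against an
# application-defined store with a check_credentials coroutine rather than
# hard-coded credentials. user_database is a hypothetical mapping.
import hmac

from websockets.legacy.auth import basic_auth_protocol_factory

user_database = {"hello": "iloveyou"}


async def check_credentials(username: str, password: str) -> bool:
    expected_password = user_database.get(username)
    if expected_password is None:
        return False
    return hmac.compare_digest(expected_password, password)


create_protocol = basic_auth_protocol_factory(
    realm="my dev server",
    check_credentials=check_credentials,
)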
717
venv/lib/python3.11/site-packages/websockets/legacy/client.py
Normal file
@ -0,0 +1,717 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import random
|
||||
import urllib.parse
|
||||
import warnings
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
from ..datastructures import Headers, HeadersLike
|
||||
from ..exceptions import (
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidMessage,
|
||||
InvalidStatusCode,
|
||||
NegotiationError,
|
||||
RedirectHandshake,
|
||||
SecurityError,
|
||||
)
|
||||
from ..extensions import ClientExtensionFactory, Extension
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import (
|
||||
build_authorization_basic,
|
||||
build_extension,
|
||||
build_host,
|
||||
build_subprotocol,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http import USER_AGENT
|
||||
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
|
||||
from ..uri import WebSocketURI, parse_uri
|
||||
from .handshake import build_request, check_response
|
||||
from .http import read_response
|
||||
from .protocol import WebSocketCommonProtocol
|
||||
|
||||
|
||||
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
|
||||
|
||||
|
||||
class WebSocketClientProtocol(WebSocketCommonProtocol):
|
||||
"""
|
||||
WebSocket client connection.
|
||||
|
||||
:class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
|
||||
coroutines for receiving and sending messages.
|
||||
|
||||
It supports asynchronous iteration to receive incoming messages::
|
||||
|
||||
async for message in websocket:
|
||||
await process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away). It raises
|
||||
a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
|
||||
is closed with any other code.
|
||||
|
||||
See :func:`connect` for the documentation of ``logger``, ``origin``,
|
||||
``extensions``, ``subprotocols``, ``extra_headers``, and
|
||||
``user_agent_header``.
|
||||
|
||||
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
|
||||
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
|
||||
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
|
||||
|
||||
"""
|
||||
|
||||
is_client = True
|
||||
side = "client"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
logger: Optional[LoggerLike] = None,
|
||||
origin: Optional[Origin] = None,
|
||||
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
|
||||
subprotocols: Optional[Sequence[Subprotocol]] = None,
|
||||
extra_headers: Optional[HeadersLike] = None,
|
||||
user_agent_header: Optional[str] = USER_AGENT,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
super().__init__(logger=logger, **kwargs)
|
||||
self.origin = origin
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
self.extra_headers = extra_headers
|
||||
self.user_agent_header = user_agent_header
|
||||
|
||||
def write_http_request(self, path: str, headers: Headers) -> None:
|
||||
"""
|
||||
Write request line and headers to the HTTP request.
|
||||
|
||||
"""
|
||||
self.path = path
|
||||
self.request_headers = headers
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("> GET %s HTTP/1.1", path)
|
||||
for key, value in headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
|
||||
# Since the path and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
request = f"GET {path} HTTP/1.1\r\n"
|
||||
request += str(headers)
|
||||
|
||||
self.transport.write(request.encode())
|
||||
|
||||
async def read_http_response(self) -> Tuple[int, Headers]:
|
||||
"""
|
||||
Read status line and headers from the HTTP response.
|
||||
|
||||
If the response contains a body, it may be read from ``self.reader``
|
||||
after this coroutine returns.
|
||||
|
||||
Raises:
|
||||
InvalidMessage: if the HTTP message is malformed or isn't an
|
||||
HTTP/1.1 GET response.
|
||||
|
||||
"""
|
||||
try:
|
||||
status_code, reason, headers = await read_response(self.reader)
|
||||
# Remove this branch when dropping support for Python < 3.8
|
||||
# because CancelledError no longer inherits Exception.
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise InvalidMessage("did not receive a valid HTTP response") from exc
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
|
||||
for key, value in headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
|
||||
self.response_headers = headers
|
||||
|
||||
return status_code, self.response_headers
|
||||
|
||||
@staticmethod
|
||||
def process_extensions(
|
||||
headers: Headers,
|
||||
available_extensions: Optional[Sequence[ClientExtensionFactory]],
|
||||
) -> List[Extension]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP response header.
|
||||
|
||||
Check that each extension is supported, as well as its parameters.
|
||||
|
||||
Return the list of accepted extensions.
|
||||
|
||||
Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
|
||||
connection.
|
||||
|
||||
:rfc:`6455` leaves the rules up to the specification of each
|
||||
        extension.
|
||||
|
||||
To provide this level of flexibility, for each extension accepted by
|
||||
the server, we check for a match with each extension available in the
|
||||
client configuration. If no match is found, an exception is raised.
|
||||
|
||||
If several variants of the same extension are accepted by the server,
|
||||
it may be configured several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
"""
|
||||
accepted_extensions: List[Extension] = []
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if header_values:
|
||||
|
||||
if available_extensions is None:
|
||||
raise InvalidHandshake("no extensions supported")
|
||||
|
||||
parsed_header_values: List[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
for name, response_params in parsed_header_values:
|
||||
|
||||
for extension_factory in available_extensions:
|
||||
|
||||
# Skip non-matching extensions based on their name.
|
||||
if extension_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
extension = extension_factory.process_response_params(
|
||||
response_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the server sent. Fail the connection.
|
||||
else:
|
||||
raise NegotiationError(
|
||||
f"Unsupported extension: "
|
||||
f"name = {name}, params = {response_params}"
|
||||
)
|
||||
|
||||
return accepted_extensions
|
||||
|
||||
@staticmethod
|
||||
def process_subprotocol(
|
||||
headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
|
||||
) -> Optional[Subprotocol]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP response header.
|
||||
|
||||
Check that it contains exactly one supported subprotocol.
|
||||
|
||||
Return the selected subprotocol.
|
||||
|
||||
"""
|
||||
subprotocol: Optional[Subprotocol] = None
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Protocol")
|
||||
|
||||
if header_values:
|
||||
|
||||
if available_subprotocols is None:
|
||||
raise InvalidHandshake("no subprotocols supported")
|
||||
|
||||
parsed_header_values: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
if len(parsed_header_values) > 1:
|
||||
subprotocols = ", ".join(parsed_header_values)
|
||||
raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")
|
||||
|
||||
subprotocol = parsed_header_values[0]
|
||||
|
||||
if subprotocol not in available_subprotocols:
|
||||
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
|
||||
|
||||
return subprotocol
|
||||
|
||||
async def handshake(
|
||||
self,
|
||||
wsuri: WebSocketURI,
|
||||
origin: Optional[Origin] = None,
|
||||
available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
|
||||
available_subprotocols: Optional[Sequence[Subprotocol]] = None,
|
||||
extra_headers: Optional[HeadersLike] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the client side of the opening handshake.
|
||||
|
||||
Args:
|
||||
wsuri: URI of the WebSocket server.
|
||||
origin: value of the ``Origin`` header.
|
||||
available_extensions: list of supported extensions, in order in
|
||||
which they should be tried.
|
||||
available_subprotocols: list of supported subprotocols, in order
|
||||
of decreasing preference.
|
||||
extra_headers: arbitrary HTTP headers to add to the request.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: if the handshake fails.
|
||||
|
||||
"""
|
||||
request_headers = Headers()
|
||||
|
||||
request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)
|
||||
|
||||
if wsuri.user_info:
|
||||
request_headers["Authorization"] = build_authorization_basic(
|
||||
*wsuri.user_info
|
||||
)
|
||||
|
||||
if origin is not None:
|
||||
request_headers["Origin"] = origin
|
||||
|
||||
key = build_request(request_headers)
|
||||
|
||||
if available_extensions is not None:
|
||||
extensions_header = build_extension(
|
||||
[
|
||||
(extension_factory.name, extension_factory.get_request_params())
|
||||
for extension_factory in available_extensions
|
||||
]
|
||||
)
|
||||
request_headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if available_subprotocols is not None:
|
||||
protocol_header = build_subprotocol(available_subprotocols)
|
||||
request_headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
if self.extra_headers is not None:
|
||||
request_headers.update(self.extra_headers)
|
||||
|
||||
if self.user_agent_header is not None:
|
||||
request_headers.setdefault("User-Agent", self.user_agent_header)
|
||||
|
||||
self.write_http_request(wsuri.resource_name, request_headers)
|
||||
|
||||
status_code, response_headers = await self.read_http_response()
|
||||
if status_code in (301, 302, 303, 307, 308):
|
||||
if "Location" not in response_headers:
|
||||
raise InvalidHeader("Location")
|
||||
raise RedirectHandshake(response_headers["Location"])
|
||||
elif status_code != 101:
|
||||
raise InvalidStatusCode(status_code, response_headers)
|
||||
|
||||
check_response(response_headers, key)
|
||||
|
||||
self.extensions = self.process_extensions(
|
||||
response_headers, available_extensions
|
||||
)
|
||||
|
||||
self.subprotocol = self.process_subprotocol(
|
||||
response_headers, available_subprotocols
|
||||
)
|
||||
|
||||
self.connection_open()
|
||||
|
||||
|
||||
class Connect:
|
||||
"""
|
||||
Connect to the WebSocket server at ``uri``.
|
||||
|
||||
Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
|
||||
can then be used to send and receive messages.
|
||||
|
||||
    :func:`connect` can be used as an asynchronous context manager::
|
||||
|
||||
async with websockets.connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
|
||||
:func:`connect` can be used as an infinite asynchronous iterator to
|
||||
reconnect automatically on errors::
|
||||
|
||||
async for websocket in websockets.connect(...):
|
||||
try:
|
||||
...
|
||||
except websockets.ConnectionClosed:
|
||||
continue
|
||||
|
||||
The connection is closed automatically after each iteration of the loop.
|
||||
|
||||
If an error occurs while establishing the connection, :func:`connect`
|
||||
retries with exponential backoff. The backoff delay starts at three
|
||||
seconds and increases up to one minute.
|
||||
|
||||
If an error occurs in the body of the loop, you can handle the exception
|
||||
and :func:`connect` will reconnect with the next iteration; or you can
|
||||
let the exception bubble up and break out of the loop. This lets you
|
||||
decide which errors trigger a reconnection and which errors are fatal.
|
||||
|
||||
Args:
|
||||
uri: URI of the WebSocket server.
|
||||
create_protocol: factory for the :class:`asyncio.Protocol` managing
|
||||
the connection; defaults to :class:`WebSocketClientProtocol`; may
|
||||
be set to a wrapper or a subclass to customize connection handling.
|
||||
logger: logger for this connection;
|
||||
defaults to ``logging.getLogger("websockets.client")``;
|
||||
see the :doc:`logging guide <../topics/logging>` for details.
|
||||
compression: shortcut that enables the "permessage-deflate" extension
|
||||
by default; may be set to :obj:`None` to disable compression;
|
||||
see the :doc:`compression guide <../topics/compression>` for details.
|
||||
origin: value of the ``Origin`` header. This is useful when connecting
|
||||
to a server that validates the ``Origin`` header to defend against
|
||||
Cross-Site WebSocket Hijacking attacks.
|
||||
extensions: list of supported extensions, in order in which they
|
||||
should be tried.
|
||||
subprotocols: list of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
extra_headers: arbitrary HTTP headers to add to the request.
|
||||
user_agent_header: value of the ``User-Agent`` request header;
|
||||
defaults to ``"Python/x.y.z websockets/X.Y"``;
|
||||
:obj:`None` removes the header.
|
||||
open_timeout: timeout for opening the connection in seconds;
|
||||
:obj:`None` to disable the timeout
|
||||
|
||||
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
|
||||
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
|
||||
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
|
||||
|
||||
    Any other keyword arguments are passed to the event loop's
|
||||
:meth:`~asyncio.loop.create_connection` method.
|
||||
|
||||
For example:
|
||||
|
||||
* You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
|
||||
settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
|
||||
provided, a TLS context is created
|
||||
with :func:`~ssl.create_default_context`.
|
||||
|
||||
* You can set ``host`` and ``port`` to connect to a different host and
|
||||
port from those found in ``uri``. This only changes the destination of
|
||||
the TCP connection. The host name from ``uri`` is still used in the TLS
|
||||
handshake for secure connections and in the ``Host`` header.
|
||||
|
||||
Returns:
|
||||
WebSocketClientProtocol: WebSocket connection.
|
||||
|
||||
Raises:
|
||||
InvalidURI: if ``uri`` isn't a valid WebSocket URI.
|
||||
InvalidHandshake: if the opening handshake fails.
|
||||
~asyncio.TimeoutError: if the opening handshake times out.
|
||||
|
||||
"""
|
||||
|
||||
MAX_REDIRECTS_ALLOWED = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
create_protocol: Optional[Callable[[Any], WebSocketClientProtocol]] = None,
|
||||
logger: Optional[LoggerLike] = None,
|
||||
compression: Optional[str] = "deflate",
|
||||
origin: Optional[Origin] = None,
|
||||
extensions: Optional[Sequence[ClientExtensionFactory]] = None,
|
||||
subprotocols: Optional[Sequence[Subprotocol]] = None,
|
||||
extra_headers: Optional[HeadersLike] = None,
|
||||
user_agent_header: Optional[str] = USER_AGENT,
|
||||
open_timeout: Optional[float] = 10,
|
||||
ping_interval: Optional[float] = 20,
|
||||
ping_timeout: Optional[float] = 20,
|
||||
close_timeout: Optional[float] = None,
|
||||
max_size: Optional[int] = 2**20,
|
||||
max_queue: Optional[int] = 2**5,
|
||||
read_limit: int = 2**16,
|
||||
write_limit: int = 2**16,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Backwards compatibility: close_timeout used to be called timeout.
|
||||
timeout: Optional[float] = kwargs.pop("timeout", None)
|
||||
if timeout is None:
|
||||
timeout = 10
|
||||
else:
|
||||
warnings.warn("rename timeout to close_timeout", DeprecationWarning)
|
||||
# If both are specified, timeout is ignored.
|
||||
if close_timeout is None:
|
||||
close_timeout = timeout
|
||||
|
||||
# Backwards compatibility: create_protocol used to be called klass.
|
||||
klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None)
|
||||
if klass is None:
|
||||
klass = WebSocketClientProtocol
|
||||
else:
|
||||
warnings.warn("rename klass to create_protocol", DeprecationWarning)
|
||||
# If both are specified, klass is ignored.
|
||||
if create_protocol is None:
|
||||
create_protocol = klass
|
||||
|
||||
# Backwards compatibility: recv() used to return None on closed connections
|
||||
legacy_recv: bool = kwargs.pop("legacy_recv", False)
|
||||
|
||||
# Backwards compatibility: the loop parameter used to be supported.
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
|
||||
if _loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
else:
|
||||
loop = _loop
|
||||
warnings.warn("remove loop argument", DeprecationWarning)
|
||||
|
||||
wsuri = parse_uri(uri)
|
||||
if wsuri.secure:
|
||||
kwargs.setdefault("ssl", True)
|
||||
elif kwargs.get("ssl") is not None:
|
||||
raise ValueError(
|
||||
"connect() received a ssl argument for a ws:// URI, "
|
||||
"use a wss:// URI to enable TLS"
|
||||
)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_client_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
factory = functools.partial(
|
||||
create_protocol,
|
||||
logger=logger,
|
||||
origin=origin,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
extra_headers=extra_headers,
|
||||
user_agent_header=user_agent_header,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_size=max_size,
|
||||
max_queue=max_queue,
|
||||
read_limit=read_limit,
|
||||
write_limit=write_limit,
|
||||
host=wsuri.host,
|
||||
port=wsuri.port,
|
||||
secure=wsuri.secure,
|
||||
legacy_recv=legacy_recv,
|
||||
loop=_loop,
|
||||
)
|
||||
|
||||
if kwargs.pop("unix", False):
|
||||
path: Optional[str] = kwargs.pop("path", None)
|
||||
create_connection = functools.partial(
|
||||
loop.create_unix_connection, factory, path, **kwargs
|
||||
)
|
||||
else:
|
||||
host: Optional[str]
|
||||
port: Optional[int]
|
||||
if kwargs.get("sock") is None:
|
||||
host, port = wsuri.host, wsuri.port
|
||||
else:
|
||||
# If sock is given, host and port shouldn't be specified.
|
||||
host, port = None, None
|
||||
# If host and port are given, override values from the URI.
|
||||
host = kwargs.pop("host", host)
|
||||
port = kwargs.pop("port", port)
|
||||
create_connection = functools.partial(
|
||||
loop.create_connection, factory, host, port, **kwargs
|
||||
)
|
||||
|
||||
self.open_timeout = open_timeout
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
self.logger = logger
|
||||
|
||||
# This is a coroutine function.
|
||||
self._create_connection = create_connection
|
||||
self._uri = uri
|
||||
self._wsuri = wsuri
|
||||
|
||||
def handle_redirect(self, uri: str) -> None:
|
||||
# Update the state of this instance to connect to a new URI.
|
||||
old_uri = self._uri
|
||||
old_wsuri = self._wsuri
|
||||
new_uri = urllib.parse.urljoin(old_uri, uri)
|
||||
new_wsuri = parse_uri(new_uri)
|
||||
|
||||
# Forbid TLS downgrade.
|
||||
if old_wsuri.secure and not new_wsuri.secure:
|
||||
raise SecurityError("redirect from WSS to WS")
|
||||
|
||||
same_origin = (
|
||||
old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
|
||||
)
|
||||
|
||||
# Rewrite the host and port arguments for cross-origin redirects.
|
||||
# This preserves connection overrides with the host and port
|
||||
# arguments if the redirect points to the same host and port.
|
||||
if not same_origin:
|
||||
# Replace the host and port argument passed to the protocol factory.
|
||||
factory = self._create_connection.args[0]
|
||||
factory = functools.partial(
|
||||
factory.func,
|
||||
*factory.args,
|
||||
**dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
|
||||
)
|
||||
# Replace the host and port argument passed to create_connection.
|
||||
self._create_connection = functools.partial(
|
||||
self._create_connection.func,
|
||||
*(factory, new_wsuri.host, new_wsuri.port),
|
||||
**self._create_connection.keywords,
|
||||
)
|
||||
|
||||
# Set the new WebSocket URI. This suffices for same-origin redirects.
|
||||
self._uri = new_uri
|
||||
self._wsuri = new_wsuri
|
||||
|
||||
# async for ... in connect(...):
|
||||
|
||||
BACKOFF_MIN = 1.92
|
||||
BACKOFF_MAX = 60.0
|
||||
BACKOFF_FACTOR = 1.618
|
||||
BACKOFF_INITIAL = 5
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
|
||||
backoff_delay = self.BACKOFF_MIN
|
||||
while True:
|
||||
try:
|
||||
async with self as protocol:
|
||||
yield protocol
|
||||
# Remove this branch when dropping support for Python < 3.8
|
||||
# because CancelledError no longer inherits Exception.
|
||||
except asyncio.CancelledError: # pragma: no cover
|
||||
raise
|
||||
except Exception:
|
||||
# Add a random initial delay between 0 and 5 seconds.
|
||||
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
|
||||
if backoff_delay == self.BACKOFF_MIN:
|
||||
initial_delay = random.random() * self.BACKOFF_INITIAL
|
||||
self.logger.info(
|
||||
"! connect failed; reconnecting in %.1f seconds",
|
||||
initial_delay,
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(initial_delay)
|
||||
else:
|
||||
self.logger.info(
|
||||
"! connect failed again; retrying in %d seconds",
|
||||
int(backoff_delay),
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(int(backoff_delay))
|
||||
# Increase delay with truncated exponential backoff.
|
||||
backoff_delay = backoff_delay * self.BACKOFF_FACTOR
|
||||
backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
|
||||
continue
|
||||
else:
|
||||
# Connection succeeded - reset backoff delay
|
||||
backoff_delay = self.BACKOFF_MIN
|
||||
|
||||
# async with connect(...) as ...:
|
||||
|
||||
async def __aenter__(self) -> WebSocketClientProtocol:
|
||||
return await self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
await self.protocol.close()
|
||||
|
||||
# ... = await connect(...)
|
||||
|
||||
def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl_timeout__().__await__()
|
||||
|
||||
async def __await_impl_timeout__(self) -> WebSocketClientProtocol:
|
||||
return await asyncio.wait_for(self.__await_impl__(), self.open_timeout)
|
||||
|
||||
async def __await_impl__(self) -> WebSocketClientProtocol:
|
||||
for redirects in range(self.MAX_REDIRECTS_ALLOWED):
|
||||
_transport, _protocol = await self._create_connection()
|
||||
protocol = cast(WebSocketClientProtocol, _protocol)
|
||||
try:
|
||||
await protocol.handshake(
|
||||
self._wsuri,
|
||||
origin=protocol.origin,
|
||||
available_extensions=protocol.available_extensions,
|
||||
available_subprotocols=protocol.available_subprotocols,
|
||||
extra_headers=protocol.extra_headers,
|
||||
)
|
||||
except RedirectHandshake as exc:
|
||||
protocol.fail_connection()
|
||||
await protocol.wait_closed()
|
||||
self.handle_redirect(exc.uri)
|
||||
# Avoid leaking a connected socket when the handshake fails.
|
||||
except (Exception, asyncio.CancelledError):
|
||||
protocol.fail_connection()
|
||||
await protocol.wait_closed()
|
||||
raise
|
||||
else:
|
||||
self.protocol = protocol
|
||||
return protocol
|
||||
else:
|
||||
raise SecurityError("too many redirects")
|
||||
|
||||
# ... = yield from connect(...) - remove when dropping Python < 3.10
|
||||
|
||||
__iter__ = __await__
|
||||
|
||||
|
||||
connect = Connect
|
||||
|
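The three entry points implemented above are typically driven like this; a minimal sketch, assuming a hypothetical local endpoint and trivial message handling:

import asyncio
import websockets

async def main() -> None:
    uri = "ws://localhost:8765"  # hypothetical endpoint

    # ... = await connect(...)
    protocol = await websockets.connect(uri)
    await protocol.close()

    # async with connect(...) as ...:
    async with websockets.connect(uri) as websocket:
        await websocket.send("hello")

    # async for ... in connect(...): reconnects with the backoff defined above
    async for websocket in websockets.connect(uri):
        try:
            async for message in websocket:
                print(message)
        except websockets.ConnectionClosed:
            continue

# asyncio.run(main())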
||||
|
||||
def unix_connect(
|
||||
path: Optional[str] = None,
|
||||
uri: str = "ws://localhost/",
|
||||
**kwargs: Any,
|
||||
) -> Connect:
|
||||
"""
|
||||
Similar to :func:`connect`, but for connecting to a Unix socket.
|
||||
|
||||
This function builds upon the event loop's
|
||||
:meth:`~asyncio.loop.create_unix_connection` method.
|
||||
|
||||
It is only available on Unix.
|
||||
|
||||
It's mainly useful for debugging servers listening on Unix sockets.
|
||||
|
||||
Args:
|
||||
path: file system path to the Unix socket.
|
||||
uri: URI of the WebSocket server; the host is used in the TLS
|
||||
handshake for secure connections and in the ``Host`` header.
|
||||
|
||||
"""
|
||||
return connect(uri=uri, path=path, unix=True, **kwargs)
|
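A minimal sketch of connecting over a Unix socket; the socket path is hypothetical and a server must already be listening there (for example via unix_serve):

import websockets

async def query() -> str:
    async with websockets.unix_connect("/tmp/websockets.sock") as websocket:
        await websocket.send("ping")
        return await websocket.recv()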
||||
@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def loop_if_py_lt_38(loop: asyncio.AbstractEventLoop) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper for the removal of the loop argument in Python 3.10.
|
||||
|
||||
"""
|
||||
return {"loop": loop} if sys.version_info[:2] < (3, 8) else {}
|
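A sketch of how this helper is meant to be used, assuming it is importable as websockets.legacy.compatibility (the file name isn't shown in this hunk):

import asyncio
from websockets.legacy.compatibility import loop_if_py_lt_38  # assumed module path

async def pause(delay: float) -> None:
    loop = asyncio.get_event_loop()
    # Expands to {"loop": loop} on Python < 3.8 and to {} on newer versions,
    # where asyncio functions stopped accepting an explicit loop argument.
    await asyncio.sleep(delay, **loop_if_py_lt_38(loop))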
||||
174
venv/lib/python3.11/site-packages/websockets/legacy/framing.py
Normal file
@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple
|
||||
|
||||
from .. import extensions, frames
|
||||
from ..exceptions import PayloadTooBig, ProtocolError
|
||||
|
||||
|
||||
try:
|
||||
from ..speedups import apply_mask
|
||||
except ImportError: # pragma: no cover
|
||||
from ..utils import apply_mask
|
||||
|
||||
|
||||
class Frame(NamedTuple):
|
||||
|
||||
fin: bool
|
||||
opcode: frames.Opcode
|
||||
data: bytes
|
||||
rsv1: bool = False
|
||||
rsv2: bool = False
|
||||
rsv3: bool = False
|
||||
|
||||
@property
|
||||
def new_frame(self) -> frames.Frame:
|
||||
return frames.Frame(
|
||||
self.opcode,
|
||||
self.data,
|
||||
self.fin,
|
||||
self.rsv1,
|
||||
self.rsv2,
|
||||
self.rsv3,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.new_frame)
|
||||
|
||||
def check(self) -> None:
|
||||
return self.new_frame.check()
|
||||
|
||||
@classmethod
|
||||
async def read(
|
||||
cls,
|
||||
reader: Callable[[int], Awaitable[bytes]],
|
||||
*,
|
||||
mask: bool,
|
||||
max_size: Optional[int] = None,
|
||||
extensions: Optional[Sequence[extensions.Extension]] = None,
|
||||
) -> Frame:
|
||||
"""
|
||||
Read a WebSocket frame.
|
||||
|
||||
Args:
|
||||
reader: coroutine that reads exactly the requested number of
|
||||
bytes, unless the end of file is reached.
|
||||
mask: whether the frame should be masked i.e. whether the read
|
||||
happens on the server side.
|
||||
max_size: maximum payload size in bytes.
|
||||
extensions: list of extensions, applied in reverse order.
|
||||
|
||||
Raises:
|
||||
PayloadTooBig: if the frame exceeds ``max_size``.
|
||||
ProtocolError: if the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
|
||||
# Read the header.
|
||||
data = await reader(2)
|
||||
head1, head2 = struct.unpack("!BB", data)
|
||||
|
||||
# While not Pythonic, this is marginally faster than calling bool().
|
||||
fin = True if head1 & 0b10000000 else False
|
||||
rsv1 = True if head1 & 0b01000000 else False
|
||||
rsv2 = True if head1 & 0b00100000 else False
|
||||
rsv3 = True if head1 & 0b00010000 else False
|
||||
|
||||
try:
|
||||
opcode = frames.Opcode(head1 & 0b00001111)
|
||||
except ValueError as exc:
|
||||
raise ProtocolError("invalid opcode") from exc
|
||||
|
||||
if (True if head2 & 0b10000000 else False) != mask:
|
||||
raise ProtocolError("incorrect masking")
|
||||
|
||||
length = head2 & 0b01111111
|
||||
if length == 126:
|
||||
data = await reader(2)
|
||||
(length,) = struct.unpack("!H", data)
|
||||
elif length == 127:
|
||||
data = await reader(8)
|
||||
(length,) = struct.unpack("!Q", data)
|
||||
if max_size is not None and length > max_size:
|
||||
raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)")
|
||||
if mask:
|
||||
mask_bits = await reader(4)
|
||||
|
||||
# Read the data.
|
||||
data = await reader(length)
|
||||
if mask:
|
||||
data = apply_mask(data, mask_bits)
|
||||
|
||||
new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3)
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in reversed(extensions):
|
||||
new_frame = extension.decode(new_frame, max_size=max_size)
|
||||
|
||||
new_frame.check()
|
||||
|
||||
return cls(
|
||||
new_frame.fin,
|
||||
new_frame.opcode,
|
||||
new_frame.data,
|
||||
new_frame.rsv1,
|
||||
new_frame.rsv2,
|
||||
new_frame.rsv3,
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
write: Callable[[bytes], Any],
|
||||
*,
|
||||
mask: bool,
|
||||
extensions: Optional[Sequence[extensions.Extension]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write a WebSocket frame.
|
||||
|
||||
Args:
|
||||
frame: frame to write.
|
||||
write: function that writes bytes.
|
||||
mask: whether the frame should be masked i.e. whether the write
|
||||
happens on the client side.
|
||||
extensions: list of extensions, applied in order.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
# The frame is written in a single call to write in order to prevent
|
||||
# TCP fragmentation. See #68 for details. This also makes it safe to
|
||||
# send frames concurrently from multiple coroutines.
|
||||
write(self.new_frame.serialize(mask=mask, extensions=extensions))
|
||||
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
|
||||
from ..frames import Close, prepare_ctrl as encode_data, prepare_data # noqa
|
||||
|
||||
|
||||
def parse_close(data: bytes) -> Tuple[int, str]:
|
||||
"""
|
||||
Parse the payload from a close frame.
|
||||
|
||||
Returns:
|
||||
Tuple[int, str]: close code and reason.
|
||||
|
||||
Raises:
|
||||
ProtocolError: if data is ill-formed.
|
||||
UnicodeDecodeError: if the reason isn't valid UTF-8.
|
||||
|
||||
"""
|
||||
close = Close.parse(data)
|
||||
return close.code, close.reason
|
||||
|
||||
|
||||
def serialize_close(code: int, reason: str) -> bytes:
|
||||
"""
|
||||
Serialize the payload for a close frame.
|
||||
|
||||
"""
|
||||
return Close(code, reason).serialize()
|
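A quick round-trip with the two helpers above, kept as a sketch:

from websockets.legacy.framing import parse_close, serialize_close

payload = serialize_close(1000, "normal closure")
assert parse_close(payload) == (1000, "normal closure")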
||||
165
venv/lib/python3.11/site-packages/websockets/legacy/handshake.py
Normal file
@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from typing import List
|
||||
|
||||
from ..datastructures import Headers, MultipleValuesError
|
||||
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
|
||||
from ..headers import parse_connection, parse_upgrade
|
||||
from ..typing import ConnectionOption, UpgradeProtocol
|
||||
from ..utils import accept_key as accept, generate_key
|
||||
|
||||
|
||||
__all__ = ["build_request", "check_request", "build_response", "check_response"]
|
||||
|
||||
|
||||
def build_request(headers: Headers) -> str:
|
||||
"""
|
||||
Build a handshake request to send to the server.
|
||||
|
||||
Update request headers passed in argument.
|
||||
|
||||
Args:
|
||||
headers: handshake request headers.
|
||||
|
||||
Returns:
|
||||
str: ``key`` that must be passed to :func:`check_response`.
|
||||
|
||||
"""
|
||||
key = generate_key()
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Key"] = key
|
||||
headers["Sec-WebSocket-Version"] = "13"
|
||||
return key
|
||||
|
||||
|
||||
def check_request(headers: Headers) -> str:
|
||||
"""
|
||||
Check a handshake request received from the client.
|
||||
|
||||
This function doesn't verify that the request is an HTTP/1.1 or higher GET
|
||||
request and doesn't perform ``Host`` and ``Origin`` checks. These controls
|
||||
are usually performed earlier in the HTTP request handling code. They're
|
||||
the responsibility of the caller.
|
||||
|
||||
Args:
|
||||
headers: handshake request headers.
|
||||
|
||||
Returns:
|
||||
str: ``key`` that must be passed to :func:`build_response`.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: if the handshake request is invalid;
|
||||
then the server must return 400 Bad Request error.
|
||||
|
||||
"""
|
||||
connection: List[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade("Connection", ", ".join(connection))
|
||||
|
||||
upgrade: List[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
|
||||
|
||||
try:
|
||||
s_w_key = headers["Sec-WebSocket-Key"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
|
||||
except binascii.Error as exc:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
|
||||
if len(raw_key) != 16:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
|
||||
|
||||
try:
|
||||
s_w_version = headers["Sec-WebSocket-Version"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
|
||||
) from exc
|
||||
|
||||
if s_w_version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
|
||||
|
||||
return s_w_key
|
||||
|
||||
|
||||
def build_response(headers: Headers, key: str) -> None:
|
||||
"""
|
||||
Build a handshake response to send to the client.
|
||||
|
||||
Update response headers passed in argument.
|
||||
|
||||
Args:
|
||||
headers: handshake response headers.
|
||||
key: returned by :func:`check_request`.
|
||||
|
||||
"""
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Accept"] = accept(key)
|
||||
|
||||
|
||||
def check_response(headers: Headers, key: str) -> None:
|
||||
"""
|
||||
Check a handshake response received from the server.
|
||||
|
||||
This function doesn't verify that the response is an HTTP/1.1 or higher
|
||||
response with a 101 status code. These controls are the responsibility of
|
||||
the caller.
|
||||
|
||||
Args:
|
||||
headers: handshake response headers.
|
||||
key: returned by :func:`build_request`.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: if the handshake response is invalid.
|
||||
|
||||
"""
|
||||
connection: List[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade("Connection", " ".join(connection))
|
||||
|
||||
upgrade: List[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
|
||||
|
||||
try:
|
||||
s_w_accept = headers["Sec-WebSocket-Accept"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
|
||||
) from exc
|
||||
|
||||
if s_w_accept != accept(key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
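The four functions above pair up across client and server; a self-contained sketch of a successful opening handshake:

from websockets.datastructures import Headers
from websockets.legacy.handshake import (
    build_request,
    build_response,
    check_request,
    check_response,
)

request_headers = Headers()
key = build_request(request_headers)            # client builds the request
assert check_request(request_headers) == key    # server validates it

response_headers = Headers()
build_response(response_headers, key)           # server builds the response
check_response(response_headers, key)           # client validates it; raises on mismatch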
||||
201
venv/lib/python3.11/site-packages/websockets/legacy/http.py
Normal file
@ -0,0 +1,201 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from ..datastructures import Headers
|
||||
from ..exceptions import SecurityError
|
||||
|
||||
|
||||
__all__ = ["read_request", "read_response"]
|
||||
|
||||
MAX_HEADERS = 256
|
||||
MAX_LINE = 4110
|
||||
|
||||
|
||||
def d(value: bytes) -> str:
|
||||
"""
|
||||
Decode a bytestring for interpolating into an error message.
|
||||
|
||||
"""
|
||||
return value.decode(errors="backslashreplace")
|
||||
|
||||
|
||||
# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B.
|
||||
|
||||
# Regex for validating header names.
|
||||
|
||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
# Regex for validating header values.
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
|
||||
|
||||
# The ABNF is complicated because it attempts to express that optional
|
||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
|
||||
|
||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
|
||||
|
||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]:
|
||||
"""
|
||||
Read an HTTP/1.1 GET request and return ``(path, headers)``.
|
||||
|
||||
``path`` isn't URL-decoded or validated in any way.
|
||||
|
||||
``path`` and ``headers`` are expected to contain only ASCII characters.
|
||||
Other characters are represented with surrogate escapes.
|
||||
|
||||
:func:`read_request` doesn't attempt to read the request body because
|
||||
WebSocket handshake requests don't have one. If the request contains a
|
||||
body, it may be read from ``stream`` after this coroutine returns.
|
||||
|
||||
Args:
|
||||
stream: input to read the request from
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without a full HTTP request
|
||||
SecurityError: if the request exceeds a security limit
|
||||
ValueError: if the request isn't well formatted
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1
|
||||
|
||||
# Parsing is simple because fixed values are expected for method and
|
||||
# version and because path isn't checked. Since WebSocket software tends
|
||||
# to implement HTTP/1.1 strictly, there's little need for lenient parsing.
|
||||
|
||||
try:
|
||||
request_line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP request line") from exc
|
||||
|
||||
try:
|
||||
method, raw_path, version = request_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
|
||||
|
||||
if method != b"GET":
|
||||
raise ValueError(f"unsupported HTTP method: {d(method)}")
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
path = raw_path.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = await read_headers(stream)
|
||||
|
||||
return path, headers
|
||||
|
||||
|
||||
async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]:
|
||||
"""
|
||||
Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.
|
||||
|
||||
``reason`` and ``headers`` are expected to contain only ASCII characters.
|
||||
Other characters are represented with surrogate escapes.
|
||||
|
||||
:func:`read_response` doesn't attempt to read the response body because
|
||||
WebSocket handshake responses don't have one. If the response contains a
|
||||
body, it may be read from ``stream`` after this coroutine returns.
|
||||
|
||||
Args:
|
||||
stream: input to read the response from
|
||||
|
||||
Raises:
|
||||
EOFError: if the connection is closed without a full HTTP response
|
||||
SecurityError: if the response exceeds a security limit
|
||||
ValueError: if the response isn't well formatted
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2
|
||||
|
||||
# As in read_request, parsing is simple because a fixed value is expected
|
||||
# for version, status_code is a 3-digit number, and reason can be ignored.
|
||||
|
||||
try:
|
||||
status_line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP status line") from exc
|
||||
|
||||
try:
|
||||
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
|
||||
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
try:
|
||||
status_code = int(raw_status_code)
|
||||
except ValueError: # invalid literal for int() with base 10
|
||||
raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
|
||||
if not 100 <= status_code < 1000:
|
||||
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
|
||||
if not _value_re.fullmatch(raw_reason):
|
||||
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
|
||||
reason = raw_reason.decode()
|
||||
|
||||
headers = await read_headers(stream)
|
||||
|
||||
return status_code, reason, headers
|
||||
|
||||
|
||||
async def read_headers(stream: asyncio.StreamReader) -> Headers:
|
||||
"""
|
||||
Read HTTP headers from ``stream``.
|
||||
|
||||
Non-ASCII characters are represented with surrogate escapes.
|
||||
|
||||
"""
|
||||
# https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
headers = Headers()
|
||||
for _ in range(MAX_HEADERS + 1):
|
||||
try:
|
||||
line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP headers") from exc
|
||||
if line == b"":
|
||||
break
|
||||
|
||||
try:
|
||||
raw_name, raw_value = line.split(b":", 1)
|
||||
except ValueError: # not enough values to unpack (expected 2, got 1)
|
||||
raise ValueError(f"invalid HTTP header line: {d(line)}") from None
|
||||
if not _token_re.fullmatch(raw_name):
|
||||
raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
|
||||
raw_value = raw_value.strip(b" \t")
|
||||
if not _value_re.fullmatch(raw_value):
|
||||
raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
|
||||
|
||||
name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
|
||||
value = raw_value.decode("ascii", "surrogateescape")
|
||||
headers[name] = value
|
||||
|
||||
else:
|
||||
raise SecurityError("too many HTTP headers")
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
async def read_line(stream: asyncio.StreamReader) -> bytes:
|
||||
"""
|
||||
Read a single line from ``stream``.
|
||||
|
||||
CRLF is stripped from the return value.
|
||||
|
||||
"""
|
||||
# Security: this is bounded by the StreamReader's limit (default = 32 KiB).
|
||||
line = await stream.readline()
|
||||
# Security: this guarantees header values are small (hard-coded = 4 KiB)
|
||||
if len(line) > MAX_LINE:
|
||||
raise SecurityError("line too long")
|
||||
# Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5
|
||||
if not line.endswith(b"\r\n"):
|
||||
raise EOFError("line without CRLF")
|
||||
return line[:-2]
|
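A sketch of parsing a canned 101 response with read_response, feeding an asyncio.StreamReader by hand:

import asyncio

from websockets.legacy.http import read_response

async def demo() -> None:
    stream = asyncio.StreamReader()
    stream.feed_data(
        b"HTTP/1.1 101 Switching Protocols\r\n"
        b"Upgrade: websocket\r\n"
        b"Connection: Upgrade\r\n"
        b"\r\n"
    )
    stream.feed_eof()
    status_code, reason, headers = await read_response(stream)
    assert status_code == 101 and headers["Upgrade"] == "websocket"

# asyncio.run(demo())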
||||
1630
venv/lib/python3.11/site-packages/websockets/legacy/protocol.py
Normal file
File diff suppressed because it is too large
1185
venv/lib/python3.11/site-packages/websockets/legacy/server.py
Normal file
File diff suppressed because it is too large
517
venv/lib/python3.11/site-packages/websockets/server.py
Normal file
@ -0,0 +1,517 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import email.utils
|
||||
import http
|
||||
from typing import Generator, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
from .connection import CONNECTING, OPEN, SERVER, Connection, State
|
||||
from .datastructures import Headers, MultipleValuesError
|
||||
from .exceptions import (
|
||||
InvalidHandshake,
|
||||
InvalidHeader,
|
||||
InvalidHeaderValue,
|
||||
InvalidOrigin,
|
||||
InvalidStatus,
|
||||
InvalidUpgrade,
|
||||
NegotiationError,
|
||||
)
|
||||
from .extensions import Extension, ServerExtensionFactory
|
||||
from .headers import (
|
||||
build_extension,
|
||||
parse_connection,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
parse_upgrade,
|
||||
)
|
||||
from .http11 import Request, Response
|
||||
from .typing import (
|
||||
ConnectionOption,
|
||||
ExtensionHeader,
|
||||
LoggerLike,
|
||||
Origin,
|
||||
Subprotocol,
|
||||
UpgradeProtocol,
|
||||
)
|
||||
from .utils import accept_key
|
||||
|
||||
|
||||
# See #940 for why lazy_import isn't used here for backwards compatibility.
|
||||
from .legacy.server import * # isort:skip # noqa
|
||||
|
||||
|
||||
__all__ = ["ServerConnection"]
|
||||
|
||||
|
||||
class ServerConnection(Connection):
|
||||
"""
|
||||
Sans-I/O implementation of a WebSocket server connection.
|
||||
|
||||
Args:
|
||||
origins: acceptable values of the ``Origin`` header; include
|
||||
:obj:`None` in the list if the lack of an origin is acceptable.
|
||||
This is useful for defending against Cross-Site WebSocket
|
||||
Hijacking attacks.
|
||||
extensions: list of supported extensions, in order in which they
|
||||
should be tried.
|
||||
subprotocols: list of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
state: initial state of the WebSocket connection.
|
||||
max_size: maximum size of incoming messages in bytes;
|
||||
:obj:`None` to disable the limit.
|
||||
logger: logger for this connection;
|
||||
defaults to ``logging.getLogger("websockets.server")``;
|
||||
see the :doc:`logging guide <../topics/logging>` for details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origins: Optional[Sequence[Optional[Origin]]] = None,
|
||||
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
|
||||
subprotocols: Optional[Sequence[Subprotocol]] = None,
|
||||
state: State = CONNECTING,
|
||||
max_size: Optional[int] = 2**20,
|
||||
logger: Optional[LoggerLike] = None,
|
||||
):
|
||||
super().__init__(
|
||||
side=SERVER,
|
||||
state=state,
|
||||
max_size=max_size,
|
||||
logger=logger,
|
||||
)
|
||||
self.origins = origins
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
|
||||
def accept(self, request: Request) -> Response:
|
||||
"""
|
||||
Create a handshake response to accept the connection.
|
||||
|
||||
If the connection cannot be established, the handshake response
|
||||
actually rejects the handshake.
|
||||
|
||||
You must send the handshake response with :meth:`send_response`.
|
||||
|
||||
You can modify it before sending it, for example to add HTTP headers.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake request event received from the client.
|
||||
|
||||
Returns:
|
||||
Response: WebSocket handshake response event to send to the client.
|
||||
|
||||
"""
|
||||
try:
|
||||
(
|
||||
accept_header,
|
||||
extensions_header,
|
||||
protocol_header,
|
||||
) = self.process_request(request)
|
||||
except InvalidOrigin as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
if self.debug:
|
||||
self.logger.debug("! invalid origin", exc_info=True)
|
||||
return self.reject(
|
||||
http.HTTPStatus.FORBIDDEN,
|
||||
f"Failed to open a WebSocket connection: {exc}.\n",
|
||||
)
|
||||
except InvalidUpgrade as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
if self.debug:
|
||||
self.logger.debug("! invalid upgrade", exc_info=True)
|
||||
response = self.reject(
|
||||
http.HTTPStatus.UPGRADE_REQUIRED,
|
||||
(
|
||||
f"Failed to open a WebSocket connection: {exc}.\n"
|
||||
f"\n"
|
||||
f"You cannot access a WebSocket server directly "
|
||||
f"with a browser. You need a WebSocket client.\n"
|
||||
),
|
||||
)
|
||||
response.headers["Upgrade"] = "websocket"
|
||||
return response
|
||||
except InvalidHandshake as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
if self.debug:
|
||||
self.logger.debug("! invalid handshake", exc_info=True)
|
||||
return self.reject(
|
||||
http.HTTPStatus.BAD_REQUEST,
|
||||
f"Failed to open a WebSocket connection: {exc}.\n",
|
||||
)
|
||||
except Exception as exc:
|
||||
request._exception = exc
|
||||
self.handshake_exc = exc
|
||||
self.logger.error("opening handshake failed", exc_info=True)
|
||||
return self.reject(
|
||||
http.HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
(
|
||||
"Failed to open a WebSocket connection.\n"
|
||||
"See server log for more information.\n"
|
||||
),
|
||||
)
|
||||
|
||||
headers = Headers()
|
||||
|
||||
headers["Date"] = email.utils.formatdate(usegmt=True)
|
||||
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Accept"] = accept_header
|
||||
|
||||
if extensions_header is not None:
|
||||
headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if protocol_header is not None:
|
||||
headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
self.logger.info("connection open")
|
||||
return Response(101, "Switching Protocols", headers)
|
||||
|
||||
def process_request(
|
||||
self, request: Request
|
||||
) -> Tuple[str, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Check a handshake request and negotiate extensions and subprotocol.
|
||||
|
||||
This function doesn't verify that the request is an HTTP/1.1 or higher
|
||||
GET request and doesn't check the ``Host`` header. These controls are
|
||||
usually performed earlier in the HTTP request handling code. They're
|
||||
the responsibility of the caller.
|
||||
|
||||
Args:
|
||||
request: WebSocket handshake request received from the client.
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[str], Optional[str]]:
|
||||
``Sec-WebSocket-Accept``, ``Sec-WebSocket-Extensions``, and
|
||||
``Sec-WebSocket-Protocol`` headers for the handshake response.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: if the handshake request is invalid;
|
||||
then the server must return 400 Bad Request error.
|
||||
|
||||
"""
|
||||
headers = request.headers
|
||||
|
||||
connection: List[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade(
|
||||
"Connection", ", ".join(connection) if connection else None
|
||||
)
|
||||
|
||||
upgrade: List[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
|
||||
|
||||
try:
|
||||
key = headers["Sec-WebSocket-Key"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(key.encode(), validate=True)
|
||||
except binascii.Error as exc:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", key) from exc
|
||||
if len(raw_key) != 16:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", key)
|
||||
|
||||
try:
|
||||
version = headers["Sec-WebSocket-Version"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader(
|
||||
"Sec-WebSocket-Version",
|
||||
"more than one Sec-WebSocket-Version header found",
|
||||
) from exc
|
||||
|
||||
if version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", version)
|
||||
|
||||
accept_header = accept_key(key)
|
||||
|
||||
self.origin = self.process_origin(headers)
|
||||
|
||||
extensions_header, self.extensions = self.process_extensions(headers)
|
||||
|
||||
protocol_header = self.subprotocol = self.process_subprotocol(headers)
|
||||
|
||||
return (
|
||||
accept_header,
|
||||
extensions_header,
|
||||
protocol_header,
|
||||
)
|
||||
|
||||
def process_origin(self, headers: Headers) -> Optional[Origin]:
|
||||
"""
|
||||
Handle the Origin HTTP request header.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake request headers.
|
||||
|
||||
Returns:
|
||||
Optional[Origin]: origin, if it is acceptable.
|
||||
|
||||
Raises:
|
||||
InvalidOrigin: if the origin isn't acceptable.
|
||||
|
||||
"""
|
||||
# "The user agent MUST NOT include more than one Origin header field"
|
||||
# per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3.
|
||||
try:
|
||||
origin = cast(Optional[Origin], headers.get("Origin"))
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Origin", "more than one Origin header found") from exc
|
||||
if self.origins is not None:
|
||||
if origin not in self.origins:
|
||||
raise InvalidOrigin(origin)
|
||||
return origin
|
||||
|
||||
def process_extensions(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Tuple[Optional[str], List[Extension]]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP request header.
|
||||
|
||||
Accept or reject each extension proposed in the client request.
|
||||
Negotiate parameters for accepted extensions.
|
||||
|
||||
:rfc:`6455` leaves the rules up to the specification of each
|
||||
extension.
|
||||
|
||||
To provide this level of flexibility, for each extension proposed by
|
||||
the client, we check for a match with each extension available in the
|
||||
server configuration. If no match is found, the extension is ignored.
|
||||
|
||||
If several variants of the same extension are proposed by the client,
|
||||
it may be accepted several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
This process doesn't allow the server to reorder extensions. It can
|
||||
only select a subset of the extensions proposed by the client.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake request headers.
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], List[Extension]]: ``Sec-WebSocket-Extensions``
|
||||
HTTP response header and list of accepted extensions.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: to abort the handshake with an HTTP 400 error.
|
||||
|
||||
"""
|
||||
response_header_value: Optional[str] = None
|
||||
|
||||
extension_headers: List[ExtensionHeader] = []
|
||||
accepted_extensions: List[Extension] = []
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if header_values and self.available_extensions:
|
||||
|
||||
parsed_header_values: List[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
for name, request_params in parsed_header_values:
|
||||
|
||||
for ext_factory in self.available_extensions:
|
||||
|
||||
# Skip non-matching extensions based on their name.
|
||||
if ext_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
response_params, extension = ext_factory.process_request_params(
|
||||
request_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
extension_headers.append((name, response_params))
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the client sent. The extension is declined.
|
||||
|
||||
# Serialize extension header.
|
||||
if extension_headers:
|
||||
response_header_value = build_extension(extension_headers)
|
||||
|
||||
return response_header_value, accepted_extensions
|
||||
|
||||
def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP request header.
|
||||
|
||||
Args:
|
||||
headers: WebSocket handshake request headers.
|
||||
|
||||
Returns:
|
||||
Optional[Subprotocol]: Subprotocol, if one was selected; this is
|
||||
also the value of the ``Sec-WebSocket-Protocol`` response header.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: to abort the handshake with an HTTP 400 error.
|
||||
|
||||
"""
|
||||
subprotocol: Optional[Subprotocol] = None
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Protocol")
|
||||
|
||||
if header_values and self.available_subprotocols:
|
||||
|
||||
parsed_header_values: List[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
subprotocol = self.select_subprotocol(
|
||||
parsed_header_values, self.available_subprotocols
|
||||
)
|
||||
|
||||
return subprotocol
|
||||
|
||||
def select_subprotocol(
|
||||
self,
|
||||
client_subprotocols: Sequence[Subprotocol],
|
||||
server_subprotocols: Sequence[Subprotocol],
|
||||
) -> Optional[Subprotocol]:
|
||||
"""
|
||||
Pick a subprotocol among those offered by the client.
|
||||
|
||||
If several subprotocols are supported by the client and the server,
|
||||
the default implementation selects the preferred subprotocols by
|
||||
giving equal value to the priorities of the client and the server.
|
||||
|
||||
If no common subprotocol is supported by the client and the server, it
|
||||
proceeds without a subprotocol.
|
||||
|
||||
This is unlikely to be the most useful implementation in practice, as
|
||||
many servers providing a subprotocol will require that the client uses
|
||||
that subprotocol.
|
||||
|
||||
Args:
|
||||
client_subprotocols: list of subprotocols offered by the client.
|
||||
server_subprotocols: list of subprotocols available on the server.
|
||||
|
||||
Returns:
|
||||
Optional[Subprotocol]: Subprotocol, if a common subprotocol was
|
||||
found.
|
||||
|
||||
"""
|
||||
subprotocols = set(client_subprotocols) & set(server_subprotocols)
|
||||
if not subprotocols:
|
||||
return None
|
||||
priority = lambda p: (
|
||||
client_subprotocols.index(p) + server_subprotocols.index(p)
|
||||
)
|
||||
return sorted(subprotocols, key=priority)[0]
|
||||
|
||||
def reject(
|
||||
self,
|
||||
status: http.HTTPStatus,
|
||||
text: str,
|
||||
) -> Response:
|
||||
"""
|
||||
Create a handshake response to reject the connection.
|
||||
|
||||
A short plain text response is the best fallback when failing to
|
||||
establish a WebSocket connection.
|
||||
|
||||
You must send the handshake response with :meth:`send_response`.
|
||||
|
||||
You can modify it before sending it, for example to alter HTTP headers.
|
||||
|
||||
Args:
|
||||
status: HTTP status code.
|
||||
text: HTTP response body; will be encoded to UTF-8.
|
||||
|
||||
Returns:
|
||||
Response: WebSocket handshake response event to send to the client.
|
||||
|
||||
"""
|
||||
body = text.encode()
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "text/plain; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
response = Response(status.value, status.phrase, headers, body)
|
||||
# When reject() is called from accept(), handshake_exc is already set.
|
||||
# If a user calls reject(), set handshake_exc to guarantee invariant:
|
||||
# "handshake_exc is None if and only if opening handshake succeded."
|
||||
if self.handshake_exc is None:
|
||||
self.handshake_exc = InvalidStatus(response)
|
||||
self.logger.info("connection failed (%d %s)", status.value, status.phrase)
|
||||
return response
|
||||
|
||||
def send_response(self, response: Response) -> None:
|
||||
"""
|
||||
Send a handshake response to the client.
|
||||
|
||||
Args:
|
||||
response: WebSocket handshake response event to send.
|
||||
|
||||
"""
|
||||
if self.debug:
|
||||
code, phrase = response.status_code, response.reason_phrase
|
||||
self.logger.debug("> HTTP/1.1 %d %s", code, phrase)
|
||||
for key, value in response.headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
if response.body is not None:
|
||||
self.logger.debug("> [body] (%d bytes)", len(response.body))
|
||||
|
||||
self.writes.append(response.serialize())
|
||||
|
||||
if response.status_code == 101:
|
||||
assert self.state is CONNECTING
|
||||
self.state = OPEN
|
||||
else:
|
||||
self.send_eof()
|
||||
self.parser = self.discard()
|
||||
next(self.parser) # start coroutine
|
||||
|
||||
def parse(self) -> Generator[None, None, None]:
|
||||
if self.state is CONNECTING:
|
||||
request = yield from Request.parse(self.reader.read_line)
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< GET %s HTTP/1.1", request.path)
|
||||
for key, value in request.headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
|
||||
self.events.append(request)
|
||||
|
||||
yield from super().parse()
|
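The subprotocol negotiation in select_subprotocol above picks the candidate with the lowest combined client/server index; the same arithmetic in isolation, as a sketch with made-up subprotocol names:

client = ["chat.v2", "chat.v1"]                 # client preference order
server = ["chat.v2", "chat.v1", "superchat"]    # server preference order

common = set(client) & set(server)
priority = lambda p: client.index(p) + server.index(p)
assert sorted(common, key=priority)[0] == "chat.v2"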
||||
223
venv/lib/python3.11/site-packages/websockets/speedups.c
Normal file
@ -0,0 +1,223 @@
|
||||
/* C implementation of performance sensitive functions. */
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
#include <stdint.h> /* uint8_t, uint32_t, uint64_t */
|
||||
|
||||
#if __ARM_NEON
|
||||
#include <arm_neon.h>
|
||||
#elif __SSE2__
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
|
||||
static const Py_ssize_t MASK_LEN = 4;
|
||||
|
||||
/* Similar to PyBytes_AsStringAndSize, but accepts more types */
|
||||
|
||||
static int
|
||||
_PyBytesLike_AsStringAndSize(PyObject *obj, PyObject **tmp, char **buffer, Py_ssize_t *length)
|
||||
{
|
||||
// This supports bytes, bytearrays, and memoryview objects,
|
||||
// which are common data structures for handling byte streams.
|
||||
// websockets.framing.prepare_data() returns only these types.
|
||||
// If *tmp isn't NULL, the caller gets a new reference.
|
||||
if (PyBytes_Check(obj))
|
||||
{
|
||||
*tmp = NULL;
|
||||
*buffer = PyBytes_AS_STRING(obj);
|
||||
*length = PyBytes_GET_SIZE(obj);
|
||||
}
|
||||
else if (PyByteArray_Check(obj))
|
||||
{
|
||||
*tmp = NULL;
|
||||
*buffer = PyByteArray_AS_STRING(obj);
|
||||
*length = PyByteArray_GET_SIZE(obj);
|
||||
}
|
||||
else if (PyMemoryView_Check(obj))
|
||||
{
|
||||
*tmp = PyMemoryView_GetContiguous(obj, PyBUF_READ, 'C');
|
||||
if (*tmp == NULL)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
Py_buffer *mv_buf;
|
||||
mv_buf = PyMemoryView_GET_BUFFER(*tmp);
|
||||
*buffer = mv_buf->buf;
|
||||
*length = mv_buf->len;
|
||||
}
|
||||
else
|
||||
{
|
||||
PyErr_Format(
|
||||
PyExc_TypeError,
|
||||
"expected a bytes-like object, %.200s found",
|
||||
Py_TYPE(obj)->tp_name);
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* C implementation of websockets.utils.apply_mask */
|
||||
|
||||
static PyObject *
|
||||
apply_mask(PyObject *self, PyObject *args, PyObject *kwds)
|
||||
{
|
||||
|
||||
// In order to support various bytes-like types, accept any Python object.
|
||||
|
||||
static char *kwlist[] = {"data", "mask", NULL};
|
||||
PyObject *input_obj;
|
||||
PyObject *mask_obj;
|
||||
|
||||
// A pointer to a char * + length will be extracted from the data and mask
|
||||
// arguments, possibly via a Py_buffer.
|
||||
|
||||
PyObject *input_tmp = NULL;
|
||||
char *input;
|
||||
Py_ssize_t input_len;
|
||||
PyObject *mask_tmp = NULL;
|
||||
char *mask;
|
||||
Py_ssize_t mask_len;
|
||||
|
||||
// Initialize a PyBytesObject then get a pointer to the underlying char *
|
||||
// in order to avoid an extra memory copy in PyBytes_FromStringAndSize.
|
||||
|
||||
PyObject *result = NULL;
|
||||
char *output;
|
||||
|
||||
// Other variables.
|
||||
|
||||
Py_ssize_t i = 0;
|
||||
|
||||
// Parse inputs.
|
||||
|
||||
if (!PyArg_ParseTupleAndKeywords(
|
||||
args, kwds, "OO", kwlist, &input_obj, &mask_obj))
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (_PyBytesLike_AsStringAndSize(input_obj, &input_tmp, &input, &input_len) == -1)
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (_PyBytesLike_AsStringAndSize(mask_obj, &mask_tmp, &mask, &mask_len) == -1)
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
if (mask_len != MASK_LEN)
|
||||
{
|
||||
PyErr_SetString(PyExc_ValueError, "mask must contain 4 bytes");
|
||||
goto exit;
|
||||
}
|
||||
|
||||
// Create output.
|
||||
|
||||
result = PyBytes_FromStringAndSize(NULL, input_len);
|
||||
if (result == NULL)
|
||||
{
|
||||
goto exit;
|
||||
}
|
||||
|
||||
// Since we just created result, we don't need error checks.
|
||||
output = PyBytes_AS_STRING(result);
|
||||
|
||||
// Perform the masking operation.
|
||||
|
||||
// Apparently GCC cannot figure out the following optimizations by itself.
|
||||
|
||||
// We need a new scope for MSVC 2010 (non C99 friendly)
|
||||
{
|
||||
#if __ARM_NEON
|
||||
|
||||
// With NEON support, XOR by blocks of 16 bytes = 128 bits.
|
||||
|
||||
Py_ssize_t input_len_128 = input_len & ~15;
|
||||
uint8x16_t mask_128 = vreinterpretq_u8_u32(vdupq_n_u32(*(uint32_t *)mask));
|
||||
|
||||
for (; i < input_len_128; i += 16)
|
||||
{
|
||||
uint8x16_t in_128 = vld1q_u8((uint8_t *)(input + i));
|
||||
uint8x16_t out_128 = veorq_u8(in_128, mask_128);
|
||||
vst1q_u8((uint8_t *)(output + i), out_128);
|
||||
}
|
||||
|
||||
#elif __SSE2__
|
||||
|
||||
// With SSE2 support, XOR by blocks of 16 bytes = 128 bits.
|
||||
|
||||
// Since we cannot control the 16-bytes alignment of input and output
|
||||
// buffers, we rely on loadu/storeu rather than load/store.
|
||||
|
||||
Py_ssize_t input_len_128 = input_len & ~15;
|
||||
__m128i mask_128 = _mm_set1_epi32(*(uint32_t *)mask);
|
||||
|
||||
for (; i < input_len_128; i += 16)
|
||||
{
|
||||
__m128i in_128 = _mm_loadu_si128((__m128i *)(input + i));
|
||||
__m128i out_128 = _mm_xor_si128(in_128, mask_128);
|
||||
_mm_storeu_si128((__m128i *)(output + i), out_128);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Without SSE2 support, XOR by blocks of 8 bytes = 64 bits.
|
||||
|
||||
// We assume the memory allocator aligns everything on 8 bytes boundaries.
|
||||
|
||||
Py_ssize_t input_len_64 = input_len & ~7;
|
||||
uint32_t mask_32 = *(uint32_t *)mask;
|
||||
uint64_t mask_64 = ((uint64_t)mask_32 << 32) | (uint64_t)mask_32;
|
||||
|
||||
for (; i < input_len_64; i += 8)
|
||||
{
|
||||
*(uint64_t *)(output + i) = *(uint64_t *)(input + i) ^ mask_64;
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
// XOR the remainder of the input byte by byte.
|
||||
|
||||
for (; i < input_len; i++)
|
||||
{
|
||||
output[i] = input[i] ^ mask[i & (MASK_LEN - 1)];
|
||||
}
|
||||
|
||||
exit:
|
||||
Py_XDECREF(input_tmp);
|
||||
Py_XDECREF(mask_tmp);
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
static PyMethodDef speedups_methods[] = {
|
||||
{
|
||||
"apply_mask",
|
||||
(PyCFunction)apply_mask,
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
"Apply masking to the data of a WebSocket message.",
|
||||
},
|
||||
{NULL, NULL, 0, NULL}, /* Sentinel */
|
||||
};
|
||||
|
||||
static struct PyModuleDef speedups_module = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"websocket.speedups", /* m_name */
|
||||
"C implementation of performance sensitive functions.",
|
||||
/* m_doc */
|
||||
-1, /* m_size */
|
||||
speedups_methods, /* m_methods */
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL
|
||||
};
|
||||
|
||||
PyMODINIT_FUNC
|
||||
PyInit_speedups(void)
|
||||
{
|
||||
return PyModule_Create(&speedups_module);
|
||||
}
|
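The C function is a drop-in replacement for the pure-Python apply_mask; a round-trip sketch using the same fallback import as framing.py:

try:
    from websockets.speedups import apply_mask  # C extension, if it was built
except ImportError:
    from websockets.utils import apply_mask     # pure-Python fallback

mask = b"\x37\xfa\x21\x3d"
masked = apply_mask(b"Hello", mask)
assert apply_mask(masked, mask) == b"Hello"     # XOR-ing twice restores the data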
||||
Binary file not shown.
151
venv/lib/python3.11/site-packages/websockets/streams.py
Normal file
@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator
|
||||
|
||||
|
||||
class StreamReader:
|
||||
"""
|
||||
Generator-based stream reader.
|
||||
|
||||
This class doesn't support concurrent calls to :meth:`read_line`,
|
||||
:meth:`read_exact`, or :meth:`read_to_eof`. Make sure calls are
|
||||
serialized.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = bytearray()
|
||||
self.eof = False
|
||||
|
||||
def read_line(self, m: int) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Read a LF-terminated line from the stream.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
The return value includes the LF character.
|
||||
|
||||
Args:
|
||||
m: maximum number of bytes to read; this is a security limit.
|
||||
|
||||
Raises:
|
||||
EOFError: if the stream ends without a LF.
|
||||
RuntimeError: if the stream ends in more than ``m`` bytes.
|
||||
|
||||
"""
|
||||
n = 0 # number of bytes to read
|
||||
p = 0 # number of bytes without a newline
|
||||
while True:
|
||||
n = self.buffer.find(b"\n", p) + 1
|
||||
if n > 0:
|
||||
break
|
||||
p = len(self.buffer)
|
||||
if p > m:
|
||||
raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
|
||||
if self.eof:
|
||||
raise EOFError(f"stream ends after {p} bytes, before end of line")
|
||||
yield
|
||||
if n > m:
|
||||
raise RuntimeError(f"read {n} bytes, expected no more than {m} bytes")
|
||||
r = self.buffer[:n]
|
||||
del self.buffer[:n]
|
||||
return r
|
||||
|
||||
def read_exact(self, n: int) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Read a given number of bytes from the stream.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
Args:
|
||||
n: how many bytes to read.
|
||||
|
||||
Raises:
|
||||
EOFError: if the stream ends in less than ``n`` bytes.
|
||||
|
||||
"""
|
||||
assert n >= 0
|
||||
while len(self.buffer) < n:
|
||||
if self.eof:
|
||||
p = len(self.buffer)
|
||||
raise EOFError(f"stream ends after {p} bytes, expected {n} bytes")
|
||||
yield
|
||||
r = self.buffer[:n]
|
||||
del self.buffer[:n]
|
||||
return r
|
||||
|
||||
def read_to_eof(self, m: int) -> Generator[None, None, bytes]:
|
||||
"""
|
||||
Read all bytes from the stream.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
Args:
|
||||
m: maximum number of bytes to read; this is a security limit.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if the stream ends in more than ``m`` bytes.
|
||||
|
||||
"""
|
||||
while not self.eof:
|
||||
p = len(self.buffer)
|
||||
if p > m:
|
||||
raise RuntimeError(f"read {p} bytes, expected no more than {m} bytes")
|
||||
yield
|
||||
r = self.buffer[:]
|
||||
del self.buffer[:]
|
||||
return r
|
||||
|
||||
def at_eof(self) -> Generator[None, None, bool]:
|
||||
"""
|
||||
Tell whether the stream has ended and all data was read.
|
||||
|
||||
This is a generator-based coroutine.
|
||||
|
||||
"""
|
||||
while True:
|
||||
if self.buffer:
|
||||
return False
|
||||
if self.eof:
|
||||
return True
|
||||
# When all data was read but the stream hasn't ended, we can't
|
||||
# tell until either feed_data() or feed_eof() is called.
|
||||
yield
|
||||
|
||||
def feed_data(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the stream.
|
||||
|
||||
:meth:`feed_data` cannot be called after :meth:`feed_eof`.
|
||||
|
||||
Args:
|
||||
data: data to write.
|
||||
|
||||
Raises:
|
||||
EOFError: if the stream has ended.
|
||||
|
||||
"""
|
||||
if self.eof:
|
||||
raise EOFError("stream ended")
|
||||
self.buffer += data
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
"""
|
||||
End the stream.
|
||||
|
||||
:meth:`feed_eof` cannot be called more than once.
|
||||
|
||||
Raises:
|
||||
EOFError: if the stream has ended.
|
||||
|
||||
"""
|
||||
if self.eof:
|
||||
raise EOFError("stream ended")
|
||||
self.eof = True
|
||||
|
||||
def discard(self) -> None:
|
||||
"""
|
||||
Discard all buffered data, but don't end the stream.
|
||||
|
||||
"""
|
||||
del self.buffer[:]
|
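A sketch of driving the generator-based reader by hand; next() advances it until enough data has been fed:

from websockets.streams import StreamReader

reader = StreamReader()
gen = reader.read_line(m=128)
next(gen)                            # nothing buffered yet, so the generator yields
reader.feed_data(b"hello\nworld")
try:
    next(gen)
except StopIteration as done:
    assert done.value == b"hello\n"  # LF included; b"world" stays buffered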
||||
60
venv/lib/python3.11/site-packages/websockets/typing.py
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, NewType, Optional, Tuple, Union
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Data",
|
||||
"LoggerLike",
|
||||
"Origin",
|
||||
"Subprotocol",
|
||||
"ExtensionName",
|
||||
"ExtensionParameter",
|
||||
]
|
||||
|
||||
|
||||
# Public types used in the signature of public APIs
|
||||
|
||||
Data = Union[str, bytes]
|
||||
"""Types supported in a WebSocket message:
|
||||
:class:`str` for a Text_ frame, :class:`bytes` for a Binary_.
|
||||
|
||||
.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
|
||||
.. _Binary : https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
|
||||
|
||||
"""
|
||||
|
||||
|
||||
LoggerLike = Union[logging.Logger, logging.LoggerAdapter]
|
||||
"""Types accepted where a :class:`~logging.Logger` is expected."""
|
||||
|
||||
|
||||
Origin = NewType("Origin", str)
|
||||
"""Value of a ``Origin`` header."""
|
||||
|
||||
|
||||
Subprotocol = NewType("Subprotocol", str)
|
||||
"""Subprotocol in a ``Sec-WebSocket-Protocol`` header."""
|
||||
|
||||
|
||||
ExtensionName = NewType("ExtensionName", str)
|
||||
"""Name of a WebSocket extension."""
|
||||
|
||||
|
||||
ExtensionParameter = Tuple[str, Optional[str]]
|
||||
"""Parameter of a WebSocket extension."""
|
||||
|
||||
|
||||
# Private types
|
||||
|
||||
ExtensionHeader = Tuple[ExtensionName, List[ExtensionParameter]]
|
||||
"""Extension in a ``Sec-WebSocket-Extensions`` header."""
|
||||
|
||||
|
||||
ConnectionOption = NewType("ConnectionOption", str)
|
||||
"""Connection option in a ``Connection`` header."""
|
||||
|
||||
|
||||
UpgradeProtocol = NewType("UpgradeProtocol", str)
|
||||
"""Upgrade protocol in an ``Upgrade`` header."""
|
||||
108
venv/lib/python3.11/site-packages/websockets/uri.py
Normal file
@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import urllib.parse
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from . import exceptions
|
||||
|
||||
|
||||
__all__ = ["parse_uri", "WebSocketURI"]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class WebSocketURI:
|
||||
"""
|
||||
WebSocket URI.
|
||||
|
||||
Attributes:
|
||||
secure: :obj:`True` for a ``wss`` URI, :obj:`False` for a ``ws`` URI.
|
||||
host: Normalized to lower case.
|
||||
port: Always set even if it's the default.
|
||||
path: May be empty.
|
||||
query: May be empty if the URI doesn't include a query component.
|
||||
username: Available when the URI contains `User Information`_.
|
||||
password: Available when the URI contains `User Information`_.
|
||||
|
||||
.. _User Information: https://www.rfc-editor.org/rfc/rfc3986.html#section-3.2.1
|
||||
|
||||
"""
|
||||
|
||||
secure: bool
|
||||
host: str
|
||||
port: int
|
||||
path: str
|
||||
query: str
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
@property
|
||||
def resource_name(self) -> str:
|
||||
if self.path:
|
||||
resource_name = self.path
|
||||
else:
|
||||
resource_name = "/"
|
||||
if self.query:
|
||||
resource_name += "?" + self.query
|
||||
return resource_name
|
||||
|
||||
@property
|
||||
def user_info(self) -> Optional[Tuple[str, str]]:
|
||||
if self.username is None:
|
||||
return None
|
||||
assert self.password is not None
|
||||
return (self.username, self.password)
|
||||
|
||||
|
||||
# All characters from the gen-delims and sub-delims sets in RFC 3986.
|
||||
DELIMS = ":/?#[]@!$&'()*+,;="
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> WebSocketURI:
|
||||
"""
|
||||
Parse and validate a WebSocket URI.
|
||||
|
||||
Args:
|
||||
uri: WebSocket URI.
|
||||
|
||||
Returns:
|
||||
WebSocketURI: Parsed WebSocket URI.
|
||||
|
||||
Raises:
|
||||
InvalidURI: if ``uri`` isn't a valid WebSocket URI.
|
||||
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(uri)
|
||||
if parsed.scheme not in ["ws", "wss"]:
|
||||
raise exceptions.InvalidURI(uri, "scheme isn't ws or wss")
|
||||
if parsed.hostname is None:
|
||||
raise exceptions.InvalidURI(uri, "hostname isn't provided")
|
||||
if parsed.fragment != "":
|
||||
raise exceptions.InvalidURI(uri, "fragment identifier is meaningless")
|
||||
|
||||
secure = parsed.scheme == "wss"
|
||||
host = parsed.hostname
|
||||
port = parsed.port or (443 if secure else 80)
|
||||
path = parsed.path
|
||||
query = parsed.query
|
||||
username = parsed.username
|
||||
password = parsed.password
|
||||
# urllib.parse.urlparse accepts URLs with a username but without a
|
||||
# password. This doesn't make sense for HTTP Basic Auth credentials.
|
||||
if username is not None and password is None:
|
||||
raise exceptions.InvalidURI(uri, "username provided without password")
|
||||
|
||||
try:
|
||||
uri.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
# Input contains non-ASCII characters.
|
||||
# It must be an IRI. Convert it to a URI.
|
||||
host = host.encode("idna").decode()
|
||||
path = urllib.parse.quote(path, safe=DELIMS)
|
||||
query = urllib.parse.quote(query, safe=DELIMS)
|
||||
if username is not None:
|
||||
assert password is not None
|
||||
username = urllib.parse.quote(username, safe=DELIMS)
|
||||
password = urllib.parse.quote(password, safe=DELIMS)
|
||||
|
||||
return WebSocketURI(secure, host, port, path, query, username, password)
|
||||
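A hedged sketch of parse_uri in use (illustrative, not part of the vendored file); the example URI is invented.

from websockets.uri import parse_uri

uri = parse_uri("wss://user:secret@example.com/chat?room=1")
assert uri.secure and uri.host == "example.com" and uri.port == 443  # wss defaults to 443
assert uri.resource_name == "/chat?room=1"
assert uri.user_info == ("user", "secret")
# parse_uri("http://example.com") would raise InvalidURI: scheme isn't ws or wss.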
51
venv/lib/python3.11/site-packages/websockets/utils.py
Normal file
51
venv/lib/python3.11/site-packages/websockets/utils.py
Normal file
@ -0,0 +1,51 @@
from __future__ import annotations

import base64
import hashlib
import secrets
import sys


__all__ = ["accept_key", "apply_mask"]


GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def generate_key() -> str:
    """
    Generate a random key for the Sec-WebSocket-Key header.

    """
    key = secrets.token_bytes(16)
    return base64.b64encode(key).decode()


def accept_key(key: str) -> str:
    """
    Compute the value of the Sec-WebSocket-Accept header.

    Args:
        key: value of the Sec-WebSocket-Key header.

    """
    sha1 = hashlib.sha1((key + GUID).encode()).digest()
    return base64.b64encode(sha1).decode()


def apply_mask(data: bytes, mask: bytes) -> bytes:
    """
    Apply masking to the data of a WebSocket message.

    Args:
        data: data to mask.
        mask: 4-bytes mask.

    """
    if len(mask) != 4:
        raise ValueError("mask must contain 4 bytes")

    data_int = int.from_bytes(data, sys.byteorder)
    mask_repeated = mask * (len(data) // 4) + mask[: len(data) % 4]
    mask_int = int.from_bytes(mask_repeated, sys.byteorder)
    return (data_int ^ mask_int).to_bytes(len(data), sys.byteorder)
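A short sketch of these helpers (illustrative, not part of the vendored file); the payload and mask values are arbitrary, and the known key/accept pair is the sample handshake from RFC 6455.

from websockets.utils import accept_key, apply_mask, generate_key

print(generate_key())  # 16 random bytes, base64-encoded, suitable for Sec-WebSocket-Key
# RFC 6455 sample handshake: this key maps to this accept value.
assert accept_key("dGhlIHNhbXBsZSBub25jZQ==") == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="

# Masking is an involution: XOR-ing with the same 4-byte mask twice restores the data.
payload, mask = b"Hello, WebSocket!", b"\x01\x02\x03\x04"
assert apply_mask(apply_mask(payload, mask), mask) == payload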
78
venv/lib/python3.11/site-packages/websockets/version.py
Normal file
78
venv/lib/python3.11/site-packages/websockets/version.py
Normal file
@ -0,0 +1,78 @@
from __future__ import annotations


__all__ = ["tag", "version", "commit"]


# ========= =========== ===================
#           release     development
# ========= =========== ===================
# tag       X.Y         X.Y (upcoming)
# version   X.Y         X.Y.dev1+g5678cde
# commit    X.Y         5678cde
# ========= =========== ===================


# When tagging a release, set `released = True`.
# After tagging a release, set `released = False` and increment `tag`.

released = True

tag = version = commit = "10.4"


if not released:  # pragma: no cover
    import pathlib
    import re
    import subprocess

    def get_version(tag: str) -> str:
        # Since setup.py executes the contents of src/websockets/version.py,
        # __file__ can point to either of these two files.
        file_path = pathlib.Path(__file__)
        root_dir = file_path.parents[0 if file_path.name == "setup.py" else 2]

        # Read version from git if available. This prevents reading stale
        # information from src/websockets.egg-info after building a sdist.
        try:
            description = subprocess.run(
                ["git", "describe", "--dirty", "--tags", "--long"],
                capture_output=True,
                cwd=root_dir,
                timeout=1,
                check=True,
                text=True,
            ).stdout.strip()
        # subprocess.run raises FileNotFoundError if git isn't on $PATH.
        except (FileNotFoundError, subprocess.CalledProcessError):
            pass
        else:
            description_re = r"[0-9.]+-([0-9]+)-(g[0-9a-f]{7,}(?:-dirty)?)"
            match = re.fullmatch(description_re, description)
            assert match is not None
            distance, remainder = match.groups()
            remainder = remainder.replace("-", ".")  # required by PEP 440
            return f"{tag}.dev{distance}+{remainder}"

        # Read version from package metadata if it is installed.
        try:
            import importlib.metadata  # move up when dropping Python 3.7

            return importlib.metadata.version("websockets")
        except ImportError:
            pass

        # Avoid crashing if the development version cannot be determined.
        return f"{tag}.dev0+gunknown"

    version = get_version(tag)

    def get_commit(tag: str, version: str) -> str:
        # Extract commit from version, falling back to tag if not available.
        version_re = r"[0-9.]+\.dev[0-9]+\+g([0-9a-f]{7,}|unknown)(?:\.dirty)?"
        match = re.fullmatch(version_re, version)
        assert match is not None
        (commit,) = match.groups()
        return tag if commit == "unknown" else commit

    commit = get_commit(tag, version)
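A quick check of the exported names (illustrative, not part of the vendored file): for this released build, tag, version, and commit all collapse to the same string.

import websockets
from websockets.version import commit, tag, version

assert websockets.__version__ == version == tag == commit == "10.4"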