diff --git a/test/test_http_proxy.py b/test/test_http_proxy.py index 97ceef9f9..191b6baf8 100644 --- a/test/test_http_proxy.py +++ b/test/test_http_proxy.py @@ -7,6 +7,7 @@ import os import random import ssl import threading +import time from http.server import BaseHTTPRequestHandler from socketserver import BaseRequestHandler, ThreadingTCPServer @@ -124,6 +125,30 @@ class HTTPSProxyHandler(HTTPProxyHandler): super().__init__(request, *args, **kwargs) +class WebSocketProxyHandler(BaseRequestHandler): + def __init__(self, *args, proxy_info=None, **kwargs): + self.proxy_info = proxy_info + super().__init__(*args, **kwargs) + + def handle(self): + import websockets.sync.server + protocol = websockets.ServerProtocol() + connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) + connection.handshake() + connection.send(json.dumps(self.proxy_info)) + connection.close() + + +class WebSocketSecureProxyHandler(WebSocketProxyHandler): + def __init__(self, request, *args, proxy_info=None, **kwargs): + self.proxy_info = proxy_info + certfn = os.path.join(TEST_DIR, 'testcert.pem') + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(certfn, None) + request = SSLTransport(request, ssl_context=sslctx, server_side=True) + super().__init__(request, *args, **kwargs) + + class HTTPConnectProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin): protocol_version = 'HTTP/1.1' default_request_version = 'HTTP/1.1' @@ -215,9 +240,30 @@ class HTTPProxyHTTPSTestContext(HTTPProxyTestContext): return json.loads(handler.send(request).read().decode()) +class HTTPProxyWebSocketTestContext(HTTPProxyTestContext): + REQUEST_HANDLER_CLASS = WebSocketProxyHandler + REQUEST_PROTO = 'ws' + + def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs): + request = Request(f'{self.REQUEST_PROTO}://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs) + handler.validate(request) + ws = handler.send(request) + ws.send('proxy_info') + socks_info = ws.recv() + ws.close() + return json.loads(socks_info) + + +class HTTPProxyWebSocketSecureTestContext(HTTPProxyWebSocketTestContext): + REQUEST_HANDLER_CLASS = WebSocketSecureProxyHandler + REQUEST_PROTO = 'wss' + + CTX_MAP = { 'http': HTTPProxyHTTPTestContext, 'https': HTTPProxyHTTPSTestContext, + 'ws': HTTPProxyWebSocketTestContext, + 'wss': HTTPProxyWebSocketSecureTestContext, } @@ -289,6 +335,8 @@ class TestHTTPProxy: 'handler,ctx', [ ('Requests', 'https'), ('CurlCFFI', 'https'), + ('Websockets', 'ws'), + ('Websockets', 'wss') ], indirect=True) class TestHTTPConnectProxy: def test_http_connect_no_auth(self, handler, ctx): diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index 6e235b0c6..776662c3a 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -25,7 +25,16 @@ from .websocket import WebSocketRequestHandler, WebSocketResponse from ..compat import functools from ..dependencies import websockets from ..socks import ProxyError as SocksProxyError -from ..utils import int_or_none +from ..utils import int_or_none, extract_basic_auth +import io +import urllib.parse +import base64 + +from http.client import HTTPResponse, HTTPConnection, HTTPSConnection + +from urllib3.util.ssltransport import SSLTransport + +from ..utils.networking import HTTPHeaderDict if not websockets: raise ImportError('websockets is not installed') @@ -98,13 +107,14 @@ class WebsocketsRH(WebSocketRequestHandler): https://github.com/python-websockets/websockets """ _SUPPORTED_URL_SCHEMES = ('wss', 'ws') - _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h', 'http', 'https') _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) RH_NAME = 'websockets' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__logging_handlers = {} + self.verbose = True for name in ('websockets.client', 'websockets.server'): logger = logging.getLogger(name) handler = logging.StreamHandler(stream=sys.stdout) @@ -125,6 +135,38 @@ class WebsocketsRH(WebSocketRequestHandler): for name, handler in self.__logging_handlers.items(): logging.getLogger(name).removeHandler(handler) + def _make_sock(self, proxy, url, timeout): + create_conn_kwargs = { + 'source_address': (self.source_address, 0) if self.source_address else None, + 'timeout': timeout + } + parsed_url = parse_uri(url) + parsed_proxy_url = urllib.parse.urlparse(proxy) + if proxy: + if parsed_proxy_url.scheme.startswith('socks'): + socks_proxy_options = make_socks_proxy_opts(proxy) + return create_connection( + address=(socks_proxy_options['addr'], socks_proxy_options['port']), + _create_socket_func=functools.partial( + create_socks_proxy_socket, (parsed_url.host, parsed_url.port), socks_proxy_options), + **create_conn_kwargs + ) + + elif parsed_proxy_url.scheme.startswith('http'): + return create_http_connect_conn( + proxy_url=proxy, + url=url, + timeout=timeout, + ssl_context=self._make_sslcontext() if parsed_proxy_url.scheme == 'https' else None, + source_address=self.source_address, + username=parsed_proxy_url.username, + password=parsed_proxy_url.password, + ) + return create_connection( + address=(parsed_url.host, parsed_url.port), + **create_conn_kwargs + ) + def _send(self, request): timeout = self._calculate_timeout(request) headers = self._merge_headers(request.headers) @@ -134,33 +176,15 @@ class WebsocketsRH(WebSocketRequestHandler): if cookie_header: headers['cookie'] = cookie_header - wsuri = parse_uri(request.url) - create_conn_kwargs = { - 'source_address': (self.source_address, 0) if self.source_address else None, - 'timeout': timeout - } proxy = select_proxy(request.url, self._get_proxies(request)) try: - if proxy: - socks_proxy_options = make_socks_proxy_opts(proxy) - sock = create_connection( - address=(socks_proxy_options['addr'], socks_proxy_options['port']), - _create_socket_func=functools.partial( - create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options), - **create_conn_kwargs - ) - else: - sock = create_connection( - address=(wsuri.host, wsuri.port), - **create_conn_kwargs - ) conn = websockets.sync.client.connect( - sock=sock, + sock=self._make_sock(proxy, request.url, timeout), uri=request.url, additional_headers=headers, open_timeout=timeout, user_agent_header=None, - ssl_context=self._make_sslcontext() if wsuri.secure else None, + ssl_context=self._make_sslcontext() if parse_uri(request.url).secure else None, close_timeout=0, # not ideal, but prevents yt-dlp hanging ) return WebsocketsResponseAdapter(conn, url=request.url) @@ -185,3 +209,75 @@ class WebsocketsRH(WebSocketRequestHandler): ) from e except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: raise TransportError(cause=e) from e + + +class NoCloseHTTPResponse(HTTPResponse): + def begin(self): + super().begin() + # Revert the default behavior of closing the connection after reading the response + if not self._check_close() and not self.chunked and self.length is None: + self.will_close = False + +class CustomSSLTransport(SSLTransport): + def setsockopt(self, *args, **kwargs): + self.socket.setsockopt(*args, **kwargs) + def shutdown(self, *args, **kwargs): + self.socket.shutdown(*args, **kwargs) + +def create_http_connect_conn( + proxy_url, + url, + timeout=None, + ssl_context=None, + source_address=None, + headers=None, + username=None, + password=None, +): + + # todo: handle ipv6 host + proxy_headers = HTTPHeaderDict({ + **(headers or {}), + }) + + if username is not None or password is not None: + proxy_headers['Proxy-Authorization'] = 'Basic ' + base64.b64encode( + f'{username}:{password}'.encode('utf-8')).decode('utf-8') + + proxy_url_parsed = urllib.parse.urlparse(proxy_url) + request_url_parsed = parse_uri(url) + + conn = HTTPConnection(proxy_url_parsed.hostname, port=proxy_url_parsed.port, timeout=timeout) + conn.response_class = NoCloseHTTPResponse + + if hasattr(conn, '_create_connection'): + conn._create_connection = create_connection + + if source_address is not None: + conn.source_address = (source_address, 0) + + conn.debuglevel=2 + try: + conn.connect() + if ssl_context: + conn.sock = CustomSSLTransport(conn.sock, ssl_context, server_hostname=proxy_url_parsed.hostname) + + conn.request(method='CONNECT', url=f'{request_url_parsed.host}:{request_url_parsed.port}', headers=proxy_headers) + response = conn.getresponse() + except OSError as e: + conn.close() + raise TransportError('Unable to connect to proxy', cause=e) from e + + if response.status == 200: + return conn.sock + elif response.status == 407: + conn.close() + raise ProxyError('Got HTTP Error 407 with CONNECT: Proxy Authentication Required') + else: + conn.close() + res_adapter = Response( + fp=io.BytesIO(b''), + url=proxy_url, headers=response.headers, + status=response.status, + reason=response.reason) + raise HTTPError(response=res_adapter)