diff --git a/ext/net/ops.rs b/ext/net/ops.rs index eded2c1294..aef4c7e427 100644 --- a/ext/net/ops.rs +++ b/ext/net/ops.rs @@ -74,6 +74,19 @@ impl From for IpAddr { } } +#[cfg(unix)] +impl From for IpAddr { + fn from(addr: tokio::net::unix::SocketAddr) -> Self { + Self { + hostname: addr.as_pathname().map_or_else( + || "unix socket".to_string(), + |p| p.to_string_lossy().into_owned(), + ), + port: 0, // Unix sockets do not have a port + } + } +} + impl From for IpAddr { fn from(addr: TunnelAddr) -> Self { Self { @@ -151,6 +164,9 @@ pub enum NetError { #[class("Busy")] #[error("TCP stream is currently in use")] TcpStreamBusy, + #[class("Busy")] + #[error("Unix stream is currently in use")] + UnixStreamBusy, #[class(generic)] #[error("{0}")] Rustls(#[from] deno_tls::rustls::Error), @@ -165,7 +181,11 @@ pub enum NetError { RootCertStore(deno_error::JsErrorBox), #[class(generic)] #[error("{0}")] - Reunite(tokio::net::tcp::ReuniteError), + ReuniteTcp(tokio::net::tcp::ReuniteError), + #[cfg(unix)] + #[class(generic)] + #[error("{0}")] + ReuniteUnix(tokio::net::unix::ReuniteError), #[class(generic)] #[error("VSOCK is not supported on this platform")] VsockUnsupported, diff --git a/ext/net/ops_tls.rs b/ext/net/ops_tls.rs index 2bdac706b5..0ab98c04c0 100644 --- a/ext/net/ops_tls.rs +++ b/ext/net/ops_tls.rs @@ -46,6 +46,8 @@ use serde::Deserialize; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; use crate::DefaultTlsOptions; use crate::NetPermissions; @@ -99,6 +101,89 @@ enum TlsStreamInner { rd: AsyncRefCell>, wr: AsyncRefCell>, }, + #[cfg(unix)] + Unix { + rd: AsyncRefCell>, + wr: AsyncRefCell>, + }, +} + +#[derive(Debug)] +#[pin_project::pin_project(project = TlsStreamReunitedProject)] +pub enum TlsStreamReunited { + Tcp(#[pin] TlsStream), + #[cfg(unix)] + Unix(#[pin] TlsStream), +} + +impl tokio::io::AsyncRead for TlsStreamReunited { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.project() { + TlsStreamReunitedProject::Tcp(s) => s.poll_read(cx, buf), + #[cfg(unix)] + TlsStreamReunitedProject::Unix(s) => s.poll_read(cx, buf), + } + } +} + +impl tokio::io::AsyncWrite for TlsStreamReunited { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.project() { + TlsStreamReunitedProject::Tcp(s) => s.poll_write(cx, buf), + #[cfg(unix)] + TlsStreamReunitedProject::Unix(s) => s.poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project() { + TlsStreamReunitedProject::Tcp(s) => s.poll_flush(cx), + #[cfg(unix)] + TlsStreamReunitedProject::Unix(s) => s.poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project() { + TlsStreamReunitedProject::Tcp(s) => s.poll_shutdown(cx), + #[cfg(unix)] + TlsStreamReunitedProject::Unix(s) => s.poll_shutdown(cx), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + TlsStreamReunited::Tcp(s) => s.is_write_vectored(), + #[cfg(unix)] + TlsStreamReunited::Unix(s) => s.is_write_vectored(), + } + } + + fn poll_write_vectored( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> std::task::Poll> { + match self.project() { + TlsStreamReunitedProject::Tcp(s) => s.poll_write_vectored(cx, bufs), + #[cfg(unix)] + TlsStreamReunitedProject::Unix(s) => s.poll_write_vectored(cx, bufs), + } + } } #[derive(Debug)] @@ -109,6 +194,33 @@ pub struct TlsStreamResource { cancel_handle: CancelHandle, // Only read and handshake ops get canceled. } +macro_rules! match_stream_inner { + ($self:expr, $field:ident, $action:block) => { + match &$self.inner { + TlsStreamInner::Tcp { .. } => { + let mut $field = RcRef::map($self, |r| match &r.inner { + TlsStreamInner::Tcp { $field, .. } => $field, + #[allow(unreachable_patterns)] + _ => unreachable!(), + }) + .borrow_mut() + .await; + $action + } + #[cfg(unix)] + TlsStreamInner::Unix { .. } => { + let mut $field = RcRef::map($self, |r| match &r.inner { + TlsStreamInner::Unix { $field, .. } => $field, + _ => unreachable!(), + }) + .borrow_mut() + .await; + $action + } + } + }; +} + impl TlsStreamResource { pub fn new_tcp( (rd, wr): (TlsStreamRead, TlsStreamWrite), @@ -123,12 +235,32 @@ impl TlsStreamResource { } } - pub fn into_tls_stream(self) -> TlsStream { + #[cfg(unix)] + pub fn new_unix( + (rd, wr): (TlsStreamRead, TlsStreamWrite), + ) -> Self { + Self { + inner: TlsStreamInner::Unix { + rd: AsyncRefCell::new(rd), + wr: AsyncRefCell::new(wr), + }, + handshake_info: RefCell::new(None), + cancel_handle: Default::default(), + } + } + + pub fn into_tls_stream(self) -> TlsStreamReunited { match self.inner { TlsStreamInner::Tcp { rd, wr } => { let read_half = rd.into_inner(); let write_half = wr.into_inner(); - read_half.unsplit(write_half) + TlsStreamReunited::Tcp(read_half.unsplit(write_half)) + } + #[cfg(unix)] + TlsStreamInner::Unix { rd, wr } => { + let read_half = rd.into_inner(); + let write_half = wr.into_inner(); + TlsStreamReunited::Unix(read_half.unsplit(write_half)) } } } @@ -149,37 +281,28 @@ impl TlsStreamResource { self: Rc, data: &mut [u8], ) -> Result { - let mut rd = RcRef::map(&self, |r| match r.inner { - TlsStreamInner::Tcp { ref rd, .. } => rd, - }) - .borrow_mut() - .await; let cancel_handle = RcRef::map(&self, |r| &r.cancel_handle); - rd.read(data).try_or_cancel(cancel_handle).await + match_stream_inner!(self, rd, { + rd.read(data).try_or_cancel(cancel_handle).await + }) } pub async fn write( self: Rc, data: &[u8], ) -> Result { - let mut wr = RcRef::map(&self, |r| match r.inner { - TlsStreamInner::Tcp { ref wr, .. } => wr, + match_stream_inner!(self, wr, { + let nwritten = wr.write(data).await?; + wr.flush().await?; + Ok(nwritten) }) - .borrow_mut() - .await; - let nwritten = wr.write(data).await?; - wr.flush().await?; - Ok(nwritten) } pub async fn shutdown(self: Rc) -> Result<(), std::io::Error> { - let mut wr = RcRef::map(&self, |r| match r.inner { - TlsStreamInner::Tcp { ref wr, .. } => wr, + match_stream_inner!(self, wr, { + wr.shutdown().await?; + Ok(()) }) - .borrow_mut() - .await; - wr.shutdown().await?; - Ok(()) } pub async fn handshake( @@ -189,20 +312,18 @@ impl TlsStreamResource { return Ok(tls_info.clone()); } - let mut wr = RcRef::map(self, |r| match r.inner { - TlsStreamInner::Tcp { ref wr, .. } => wr, - }) - .borrow_mut() - .await; let cancel_handle = RcRef::map(self, |r| &r.cancel_handle); - let handshake = wr.handshake().try_or_cancel(cancel_handle).await?; + let tls_info = match_stream_inner!(self, wr, { + let handshake = wr.handshake().try_or_cancel(cancel_handle).await?; + + let alpn_protocol = handshake.alpn.map(|alpn| alpn.into()); + let peer_certificates = handshake.peer_certificates.clone(); + TlsHandshakeInfo { + alpn_protocol, + peer_certificates, + } + }); - let alpn_protocol = handshake.alpn.map(|alpn| alpn.into()); - let peer_certificates = handshake.peer_certificates.clone(); - let tls_info = TlsHandshakeInfo { - alpn_protocol, - peer_certificates, - }; self.handshake_info.replace(Some(tls_info.clone())); Ok(tls_info) } @@ -349,22 +470,6 @@ where .root_cert_store() .map_err(NetError::RootCertStore)?; - let resource_rc = state - .borrow_mut() - .resource_table - .take::(rid) - .map_err(NetError::Resource)?; - // This TCP connection might be used somewhere else. If it's the case, we cannot proceed with the - // process of starting a TLS connection on top of this TCP connection, so we just return a Busy error. - // See also: https://github.com/denoland/deno/pull/16242 - let resource = - Rc::try_unwrap(resource_rc).map_err(|_| NetError::TcpStreamBusy)?; - let (read_half, write_half) = resource.into_inner(); - let tcp_stream = read_half.reunite(write_half).map_err(NetError::Reunite)?; - - let local_addr = tcp_stream.local_addr()?; - let remote_addr = tcp_stream.peer_addr()?; - let tls_null = TlsKeysHolder::from(TlsKeys::Null); let key_pair = key_pair.unwrap_or(&tls_null); let mut tls_config = create_client_config( @@ -381,20 +486,68 @@ where } let tls_config = Arc::new(tls_config); - let tls_stream = TlsStream::new_client_side( - tcp_stream, - ClientConnection::new(tls_config, hostname_dns)?, - TLS_BUFFER_SIZE, - ); + let resource_table = &mut state.borrow_mut().resource_table; - let rid = { - let mut state_ = state.borrow_mut(); - state_ - .resource_table - .add(TlsStreamResource::new_tcp(tls_stream.into_split())) - }; + let r = resource_table + .take::(rid) + .map_err(NetError::Resource); + if let Ok(resource_rc) = r { + // This TCP connection might be used somewhere else. If it's the case, we cannot proceed with the + // process of starting a TLS connection on top of this TCP connection, so we just return a Busy error. + // See also: https://github.com/denoland/deno/pull/16242 + let resource = + Rc::try_unwrap(resource_rc).map_err(|_| NetError::TcpStreamBusy)?; + let (read_half, write_half) = resource.into_inner(); + let tcp_stream = read_half + .reunite(write_half) + .map_err(NetError::ReuniteTcp)?; - Ok((rid, IpAddr::from(local_addr), IpAddr::from(remote_addr))) + let local_addr = tcp_stream.local_addr()?; + let remote_addr = tcp_stream.peer_addr()?; + + let tls_stream = TlsStream::new_client_side( + tcp_stream, + ClientConnection::new(tls_config, hostname_dns)?, + TLS_BUFFER_SIZE, + ); + + let rid = { + resource_table.add(TlsStreamResource::new_tcp(tls_stream.into_split())) + }; + + return Ok((rid, IpAddr::from(local_addr), IpAddr::from(remote_addr))); + } + + #[cfg(unix)] + if let Ok(resource_rc) = + resource_table.take::(rid) + { + // This UNIX socket might be used somewhere else. + let resource = + Rc::try_unwrap(resource_rc).map_err(|_| NetError::UnixStreamBusy)?; + let (read_half, write_half) = resource.into_inner(); + let unix_stream = read_half + .reunite(write_half) + .map_err(NetError::ReuniteUnix)?; + let local_addr = unix_stream.local_addr()?; + let remote_addr = unix_stream.peer_addr()?; + + let tls_stream = TlsStream::new_client_side( + unix_stream, + ClientConnection::new(tls_config, hostname_dns)?, + TLS_BUFFER_SIZE, + ); + + let rid = { + resource_table.add(TlsStreamResource::new_unix(tls_stream.into_split())) + }; + + return Ok((rid, IpAddr::from(local_addr), IpAddr::from(remote_addr))); + } + + Err(NetError::Resource( + deno_core::error::ResourceError::BadResourceId, + )) } #[op2(async, stack_trace)] diff --git a/ext/net/raw.rs b/ext/net/raw.rs index aa21bf73db..7819d0c3b1 100644 --- a/ext/net/raw.rs +++ b/ext/net/raw.rs @@ -13,6 +13,7 @@ use deno_error::JsErrorBox; use crate::io::TcpStreamResource; use crate::ops_tls::TlsStreamResource; +use crate::ops_tls::TlsStreamReunited; pub trait NetworkStreamTrait: Into { type Resource; @@ -462,7 +463,15 @@ pub fn take_network_stream_resource( let resource = Rc::try_unwrap(resource_rc) .map_err(|_| TakeNetworkStreamError::TlsBusy)?; let tls_stream = resource.into_tls_stream(); - return Ok(NetworkStream::Tls(tls_stream)); + + match tls_stream { + TlsStreamReunited::Tcp(tcp_stream) => { + return Ok(NetworkStream::Tls(tcp_stream)); + } + // TODO(bartlomieju): support unix sockets here + #[allow(unreachable_patterns)] + _ => todo!(), + } } #[cfg(unix)] diff --git a/ext/node/polyfills/_tls_wrap.ts b/ext/node/polyfills/_tls_wrap.ts index 60f1295c63..f2b738070a 100644 --- a/ext/node/polyfills/_tls_wrap.ts +++ b/ext/node/polyfills/_tls_wrap.ts @@ -37,6 +37,7 @@ import { import { startTlsInternal } from "ext:deno_net/02_tls.js"; import { internals } from "ext:core/mod.js"; import { op_tls_canonicalize_ipv4_address } from "ext:core/ops"; +import console from "node:console"; const kConnectOptions = Symbol("connect-options"); const kIsVerified = Symbol("verified"); @@ -89,7 +90,10 @@ export class TLSSocket extends net.Socket { ssl: any; _start() { - this[kHandle].afterConnectTls(); + this.connecting = true; + if (this[kHandle] && this[kHandle][kStreamBaseField]) { + this[kHandle].afterConnectTls?.(); + } } constructor(socket: any, opts: any = kEmptyObject) { @@ -164,15 +168,18 @@ export class TLSSocket extends net.Socket { // Set `afterConnectTls` hook. This is called in the `afterConnect` method of net.Socket handle.afterConnectTls = async () => { + handle.afterConnectTls = undefined; options.hostname ??= undefined; // coerce to undefined if null, startTls expects hostname to be undefined if (tlssock._needsSockInitWorkaround) { // skips the TLS handshake for @npmcli/agent as it's handled by // onSocket handler of ClientRequest object. - tlssock.emit("secure"); + tlssock.emit("secureConnect"); tlssock.removeListener("end", onConnectEnd); return; } + console.log("startTlsInternal", handle[kStreamBaseField]); + console.log("start tls", options.isServer ? "[server]" : "[client]"); try { const conn = await startTlsInternal( handle[kStreamBaseField], @@ -190,19 +197,25 @@ export class TLSSocket extends net.Socket { // operation emit the error. } + console.log("done tls", options.isServer ? "[server]" : "[client]"); + // Assign the TLS connection to the handle and resume reading. handle[kStreamBaseField] = conn; handle.upgrading = false; + tlssock.connecting = false; if (!handle.pauseOnCreate) { handle.readStart(); } resolve(); - tlssock.emit("secure"); + tlssock.emit("connect"); + tlssock.emit("ready"); + tlssock.emit("secureConnect"); tlssock.removeListener("end", onConnectEnd); - } catch { + } catch (e) { // TODO(kt3k): Handle this + console.log("handle.afterConnecTls error", e); } }; diff --git a/ext/node/polyfills/net.ts b/ext/node/polyfills/net.ts index ea3f74bf32..2378edc8ed 100644 --- a/ext/node/polyfills/net.ts +++ b/ext/node/polyfills/net.ts @@ -41,6 +41,7 @@ import { newAsyncId, ownerSymbol, } from "ext:deno_node/internal/async_hooks.ts"; +import { kStreamBaseField } from "ext:deno_node/internal_binding/stream_wrap.ts"; import { ERR_INVALID_ADDRESS_FAMILY, ERR_INVALID_ARG_TYPE, @@ -380,7 +381,9 @@ function _afterConnect( // Deno specific: run tls handshake if it's from a tls socket // This swaps the handle[kStreamBaseField] from TcpConn to TlsConn - if (typeof handle.afterConnectTls === "function") { + if ( + typeof handle.afterConnectTls === "function" && handle[kStreamBaseField] + ) { handle.afterConnectTls(); } @@ -1287,7 +1290,6 @@ export class Socket extends Duplex { } this.on("end", _onReadableStreamEnd); - _initSocketHandle(this); // If we have a handle, then start the flow of data into the diff --git a/runtime/ops/http.rs b/runtime/ops/http.rs index aa88668f42..a7407707d6 100644 --- a/runtime/ops/http.rs +++ b/runtime/ops/http.rs @@ -9,6 +9,7 @@ use deno_core::op2; use deno_http::http_create_conn_resource; use deno_net::io::TcpStreamResource; use deno_net::ops_tls::TlsStreamResource; +use deno_net::ops_tls::TlsStreamReunited; pub const UNSTABLE_FEATURE_NAME: &str = "http"; @@ -75,7 +76,13 @@ fn op_http_start( let resource = Rc::try_unwrap(resource_rc) .map_err(|_| HttpStartError::TlsStreamInUse)?; let tls_stream = resource.into_tls_stream(); - let addr = tls_stream.local_addr()?; + let addr = match tls_stream { + TlsStreamReunited::Tcp(ref s) => s.local_addr()?, + // TODO(bartlomieju): handle Unix socket + #[allow(unreachable_patterns)] + _ => todo!(), + }; + return Ok(http_create_conn_resource(state, tls_stream, addr, "https")); } diff --git a/tests/unit_node/tls_test.ts b/tests/unit_node/tls_test.ts index 79fe3ab23b..9758db58be 100644 --- a/tests/unit_node/tls_test.ts +++ b/tests/unit_node/tls_test.ts @@ -13,6 +13,7 @@ import * as tls from "node:tls"; import * as net from "node:net"; import * as stream from "node:stream"; import { execCode } from "../unit/test_util.ts"; +import console from "node:console"; const tlsTestdataDir = fromFileUrl( new URL("../testdata/tls", import.meta.url), @@ -82,6 +83,7 @@ Host: localhost Connection: close `); + const chunk = Promise.withResolvers(); conn.on("data", (received) => { conn.destroy(); @@ -304,3 +306,139 @@ Deno.test({ } assertEquals(new TextDecoder().decode(stdout), ""); }); + +// TODO(bartlomieju): this test currently doesn't pass, because server-side +// socket doesn't handle TLS correctly. +Deno.test({ + name: "tls.connect over unix socket works", + ignore: true, + // ignore: Deno.build.os === "windows", + permissions: { read: true, write: true }, +}, async () => { + const socketPath = "/tmp/tls_unix_test.sock"; + + try { + await Deno.remove(socketPath); + } catch { + // pass + } + + let serverError: unknown = null; + let clientError: unknown = null; + + const { promise: serverReady, resolve: resolveServerReady } = Promise + .withResolvers(); + const { promise: testComplete, resolve: resolveTestComplete } = Promise + .withResolvers(); + const { promise: clientDataReceived, resolve: resolveClientDataReceived } = + Promise.withResolvers(); + + const netServer = net.createServer((rawSocket) => { + try { + console.log("before create"); + const secureSocket = new tls.TLSSocket(rawSocket, { + key, + cert, + isServer: true, + }); + + secureSocket.on("secureConnect", () => { + console.log("secure socket on secureConnect"); + secureSocket.write("hello from server"); + }); + + secureSocket.on("data", (data) => { + console.log( + "secure socket on data", + data.byteLength, + data.toString(), + ); + assertEquals(data.toString(), "hello from client"); + secureSocket.end(); + }); + + secureSocket.on("close", () => { + console.log("secure socket on close"); + resolveTestComplete(); + }); + + secureSocket.on("error", (err) => { + console.log("secure socket on error"); + serverError = err; + resolveTestComplete(); + }); + } catch (err) { + serverError = err; + resolveTestComplete(); + } + }); + + netServer.on("error", (err) => { + serverError = err; + resolveTestComplete(); + }); + + netServer.listen(socketPath, () => { + resolveServerReady(); + }); + console.log("before server ready"); + await serverReady; + console.log("after server ready"); + try { + const rawSocket = net.connect(socketPath); + + const secureSocket = tls.connect({ + socket: rawSocket, + rejectUnauthorized: false, + }); + + rawSocket.on("error", (err) => { + console.log("raw socket on err", err); + clientError = err; + resolveTestComplete(); + }); + + secureSocket.on("secureConnect", () => { + console.log("secure socket on secureConnect"); + secureSocket.write("hello from client"); + }); + + secureSocket.on("data", (data) => { + console.log("secure socket on data"); + resolveClientDataReceived(data.toString()); + }); + + secureSocket.on("error", (err) => { + console.log("secure socket on error"); + clientError = err; + resolveTestComplete(); + }); + + console.log("before client data received"); + const receivedData = await clientDataReceived; + console.log("after client data received"); + assertEquals(receivedData, "hello from server"); + console.log("before test complete"); + await testComplete; + console.log("after test complete"); + if (serverError) { + console.error("Server error:", serverError); + } + if (clientError) { + console.error("Client error:", clientError); + } + + secureSocket.destroy(); + } catch (err) { + clientError = err; + console.error("Test setup error:", err); + } + + netServer.close(); + + try { + await Deno.remove(socketPath); + } catch { + // pass + } +});