diff --git a/src/java.base/share/classes/java/net/Socket.java b/src/java.base/share/classes/java/net/Socket.java index 83c0dec682c..929e99b5303 100644 --- a/src/java.base/share/classes/java/net/Socket.java +++ b/src/java.base/share/classes/java/net/Socket.java @@ -454,6 +454,7 @@ public class Socket implements java.io.Closeable { throws IOException { Objects.requireNonNull(address); + assert address instanceof InetSocketAddress; // create the SocketImpl and the underlying socket SocketImpl impl = createImpl(); @@ -463,16 +464,13 @@ public class Socket implements java.io.Closeable { this.state = SOCKET_CREATED; try { - if (localAddr != null) + if (localAddr != null) { bind(localAddr); - connect(address); - } catch (IOException | IllegalArgumentException e) { - try { - close(); - } catch (IOException ce) { - e.addSuppressed(ce); } - throw e; + connect(address); + } catch (Throwable throwable) { + closeSuppressingExceptions(throwable); + throw throwable; } } @@ -571,6 +569,10 @@ public class Socket implements java.io.Closeable { /** * Connects this socket to the server. * + *

If the endpoint is an unresolved {@link InetSocketAddress}, or the + * connection cannot be established, then the socket is closed, and an + * {@link IOException} is thrown. + * *

This method is {@linkplain Thread#interrupt() interruptible} in the * following circumstances: *

    @@ -589,6 +591,8 @@ public class Socket implements java.io.Closeable { * @param endpoint the {@code SocketAddress} * @throws IOException if an error occurs during the connection, the socket * is already connected or the socket is closed + * @throws UnknownHostException if the endpoint is an unresolved + * {@link InetSocketAddress} * @throws java.nio.channels.IllegalBlockingModeException * if this socket has an associated channel, * and the channel is in non-blocking mode @@ -605,6 +609,11 @@ public class Socket implements java.io.Closeable { * A timeout of zero is interpreted as an infinite timeout. The connection * will then block until established or an error occurs. * + *

    If the endpoint is an unresolved {@link InetSocketAddress}, the + * connection cannot be established, or the timeout expires before the + * connection is established, then the socket is closed, and an + * {@link IOException} is thrown. + * *

    This method is {@linkplain Thread#interrupt() interruptible} in the * following circumstances: *

      @@ -625,6 +634,8 @@ public class Socket implements java.io.Closeable { * @throws IOException if an error occurs during the connection, the socket * is already connected or the socket is closed * @throws SocketTimeoutException if timeout expires before connecting + * @throws UnknownHostException if the endpoint is an unresolved + * {@link InetSocketAddress} * @throws java.nio.channels.IllegalBlockingModeException * if this socket has an associated channel, * and the channel is in non-blocking mode @@ -644,26 +655,25 @@ public class Socket implements java.io.Closeable { if (isClosed(s)) throw new SocketException("Socket is closed"); if (isConnected(s)) - throw new SocketException("already connected"); + throw new SocketException("Already connected"); if (!(endpoint instanceof InetSocketAddress epoint)) throw new IllegalArgumentException("Unsupported address type"); + if (epoint.isUnresolved()) { + var uhe = new UnknownHostException(epoint.getHostName()); + closeSuppressingExceptions(uhe); + throw uhe; + } + InetAddress addr = epoint.getAddress(); - int port = epoint.getPort(); checkAddress(addr, "connect"); try { getImpl().connect(epoint, timeout); - } catch (SocketTimeoutException e) { - throw e; - } catch (InterruptedIOException e) { - Thread thread = Thread.currentThread(); - if (thread.isVirtual() && thread.isInterrupted()) { - close(); - throw new SocketException("Closed by interrupt"); - } - throw e; + } catch (IOException error) { + closeSuppressingExceptions(error); + throw error; } // connect will bind the socket if not previously bound @@ -1589,6 +1599,14 @@ public class Socket implements java.io.Closeable { return ((Boolean) (getImpl().getOption(SocketOptions.SO_REUSEADDR))).booleanValue(); } + private void closeSuppressingExceptions(Throwable parentException) { + try { + close(); + } catch (IOException exception) { + parentException.addSuppressed(exception); + } + } + /** * Closes this socket. *

      diff --git a/src/java.base/share/classes/sun/nio/ch/Net.java b/src/java.base/share/classes/sun/nio/ch/Net.java index 03dcf04a50f..5c922aff676 100644 --- a/src/java.base/share/classes/sun/nio/ch/Net.java +++ b/src/java.base/share/classes/sun/nio/ch/Net.java @@ -40,6 +40,7 @@ import java.net.StandardProtocolFamily; import java.net.StandardSocketOptions; import java.net.UnknownHostException; import java.nio.channels.AlreadyBoundException; +import java.nio.channels.AlreadyConnectedException; import java.nio.channels.ClosedChannelException; import java.nio.channels.NotYetBoundException; import java.nio.channels.NotYetConnectedException; @@ -166,6 +167,8 @@ public class Net { nx = newSocketException("Socket is not connected"); else if (x instanceof AlreadyBoundException) nx = newSocketException("Already bound"); + else if (x instanceof AlreadyConnectedException) + nx = newSocketException("Already connected"); else if (x instanceof NotYetBoundException) nx = newSocketException("Socket is not bound yet"); else if (x instanceof UnsupportedAddressTypeException) @@ -190,32 +193,12 @@ public class Net { return new SocketException(msg); } - static void translateException(Exception x, - boolean unknownHostForUnresolved) - throws IOException - { + static void translateException(Exception x) throws IOException { if (x instanceof IOException ioe) throw ioe; - // Throw UnknownHostException from here since it cannot - // be thrown as a SocketException - if (unknownHostForUnresolved && - (x instanceof UnresolvedAddressException)) - { - throw new UnknownHostException(); - } translateToSocketException(x); } - static void translateException(Exception x) - throws IOException - { - translateException(x, false); - } - - private static InetSocketAddress getLoopbackAddress(int port) { - return new InetSocketAddress(InetAddress.getLoopbackAddress(), port); - } - private static final InetAddress ANY_LOCAL_INET4ADDRESS; private static final InetAddress ANY_LOCAL_INET6ADDRESS; private static final InetAddress INET4_LOOPBACK_ADDRESS; diff --git a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java index 40bc3156680..7c64d6721e2 100644 --- a/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java +++ b/src/java.base/share/classes/sun/nio/ch/NioSocketImpl.java @@ -599,8 +599,11 @@ public final class NioSocketImpl extends SocketImpl implements PlatformSocketImp } } catch (IOException ioe) { close(); - if (ioe instanceof InterruptedIOException) { + if (ioe instanceof SocketTimeoutException) { throw ioe; + } else if (ioe instanceof InterruptedIOException) { + assert Thread.currentThread().isVirtual(); + throw new SocketException("Closed by interrupt"); } else { throw SocketExceptions.of(ioe, isa); } diff --git a/src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java b/src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java index cbcfd79378c..d8ed1cfb675 100644 --- a/src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java +++ b/src/java.base/share/classes/sun/nio/ch/SocketAdaptor.java @@ -35,6 +35,7 @@ import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketOption; import java.net.StandardSocketOptions; +import java.net.UnknownHostException; import java.nio.channels.SocketChannel; import java.util.Set; @@ -85,6 +86,14 @@ class SocketAdaptor public void connect(SocketAddress remote, int timeout) throws IOException { if (remote == null) throw new IllegalArgumentException("connect: The address can't be null"); + if (remote instanceof InetSocketAddress isa && isa.isUnresolved()) { + if (!sc.isOpen()) + throw new SocketException("Socket is closed"); + if (sc.isConnected()) + throw new SocketException("Already connected"); + close(); + throw new UnknownHostException(remote.toString()); + } if (timeout < 0) throw new IllegalArgumentException("connect: timeout can't be negative"); try { @@ -95,7 +104,7 @@ class SocketAdaptor sc.blockingConnect(remote, Long.MAX_VALUE); } } catch (Exception e) { - Net.translateException(e, true); + Net.translateException(e); } } diff --git a/test/jdk/java/net/Socket/ConnectFailTest.java b/test/jdk/java/net/Socket/ConnectFailTest.java new file mode 100644 index 00000000000..7cc46ce4a4d --- /dev/null +++ b/test/jdk/java/net/Socket/ConnectFailTest.java @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import jdk.test.lib.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; +import java.net.UnknownHostException; +import java.nio.channels.SocketChannel; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/* + * @test + * @bug 8343791 + * @summary verifies that `connect()` failures throw the expected exception and leave socket in the expected state + * @library /test/lib + * @run junit ConnectFailTest + */ +class ConnectFailTest { + + private static final int DEAD_SERVER_PORT = 0xDEAD; + + private static final InetSocketAddress REFUSING_SOCKET_ADDRESS = Utils.refusingEndpoint(); + + private static final InetSocketAddress UNRESOLVED_ADDRESS = + InetSocketAddress.createUnresolved("no.such.host", DEAD_SERVER_PORT); + + @Test + void testUnresolvedAddress() { + assertTrue(UNRESOLVED_ADDRESS.isUnresolved()); + } + + /** + * Verifies that an unbound socket is closed when {@code connect()} fails. + */ + @ParameterizedTest + @MethodSource("sockets") + void testUnboundSocket(Socket socket) throws IOException { + try (socket) { + assertFalse(socket.isBound()); + assertFalse(socket.isConnected()); + assertThrows(IOException.class, () -> socket.connect(REFUSING_SOCKET_ADDRESS)); + assertTrue(socket.isClosed()); + } + } + + /** + * Verifies that a bound socket is closed when {@code connect()} fails. + */ + @ParameterizedTest + @MethodSource("sockets") + void testBoundSocket(Socket socket) throws IOException { + try (socket) { + socket.bind(new InetSocketAddress(0)); + assertTrue(socket.isBound()); + assertFalse(socket.isConnected()); + assertThrows(IOException.class, () -> socket.connect(REFUSING_SOCKET_ADDRESS)); + assertTrue(socket.isClosed()); + } + } + + /** + * Verifies that a connected socket is not closed when {@code connect()} fails. + */ + @ParameterizedTest + @MethodSource("sockets") + void testConnectedSocket(Socket socket) throws Throwable { + try (socket; ServerSocket serverSocket = createEphemeralServerSocket()) { + socket.connect(serverSocket.getLocalSocketAddress()); + try (Socket _ = serverSocket.accept()) { + assertTrue(socket.isBound()); + assertTrue(socket.isConnected()); + SocketException exception = assertThrows( + SocketException.class, + () -> socket.connect(REFUSING_SOCKET_ADDRESS)); + assertEquals("Already connected", exception.getMessage()); + assertFalse(socket.isClosed()); + } + } + } + + /** + * Verifies that an unbound socket is closed when {@code connect()} is invoked using an unresolved address. + */ + @ParameterizedTest + @MethodSource("sockets") + void testUnboundSocketWithUnresolvedAddress(Socket socket) throws IOException { + try (socket) { + assertFalse(socket.isBound()); + assertFalse(socket.isConnected()); + assertThrows(UnknownHostException.class, () -> socket.connect(UNRESOLVED_ADDRESS)); + assertTrue(socket.isClosed()); + } + } + + /** + * Verifies that a bound socket is closed when {@code connect()} is invoked using an unresolved address. + */ + @ParameterizedTest + @MethodSource("sockets") + void testBoundSocketWithUnresolvedAddress(Socket socket) throws IOException { + try (socket) { + socket.bind(new InetSocketAddress(0)); + assertTrue(socket.isBound()); + assertFalse(socket.isConnected()); + assertThrows(UnknownHostException.class, () -> socket.connect(UNRESOLVED_ADDRESS)); + assertTrue(socket.isClosed()); + } + } + + /** + * Verifies that a connected socket is not closed when {@code connect()} is invoked using an unresolved address. + */ + @ParameterizedTest + @MethodSource("sockets") + void testConnectedSocketWithUnresolvedAddress(Socket socket) throws Throwable { + try (socket; ServerSocket serverSocket = createEphemeralServerSocket()) { + socket.connect(serverSocket.getLocalSocketAddress()); + try (Socket _ = serverSocket.accept()) { + assertTrue(socket.isBound()); + assertTrue(socket.isConnected()); + assertThrows(IOException.class, () -> socket.connect(UNRESOLVED_ADDRESS)); + assertFalse(socket.isClosed()); + } + } + } + + static List sockets() throws Exception { + Socket socket = new Socket(); + @SuppressWarnings("resource") + Socket channelSocket = SocketChannel.open().socket(); + return List.of(socket, channelSocket); + } + + private static ServerSocket createEphemeralServerSocket() throws IOException { + return new ServerSocket(0, 0, InetAddress.getLoopbackAddress()); + } + +}