java

java 21 Virtual Thread Tcp Socket Server

kimbs0301 2024. 6. 30. 14:46

Java Virtual Thread Tcp Socket Server

java 11 소켓 클래스 java.net.SocketInputStream

java 21 소켓 클래스 sun.nio.ch.NioSocketImpl

JEP 353 (Reimplement the legacy Socket API) https://openjdk.org/jeps/353

에서 기존 Socket API의 내부 구현을 NIO 기반(NioSocketImpl)으로 재구현하였기 때문에, 기존 코드를 변경하지 않고도 가상 스레드에서 블로킹 소켓 API를 그대로 사용할 수 있다.

Go 언어의 TCP 소켓 서버 구현 방식(런타임 netpoller 위에서 연결마다 goroutine을 하나씩 사용하는 방식)과 유사하게 구현하였다.

java 1.4(NIO 도입) 이전의 블로킹 IO 방식 TCP 소켓 서버와 다른 점은, 연결된 클라이언트를 플랫폼 스레드가 아닌 가상 스레드로 처리한다는 것이다.

연결 클라이언트 관리, 단편화된 패킷 수신 처리, 버퍼 풀 등 추가 구현이 필요하다.

 

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.CountDownLatch;

/**
 * Entry point of the virtual-thread TCP server.
 *
 * <p>One virtual thread ("tcp-listener") accepts connections. Each accepted socket is handed to
 * a {@link SocketClient}, which reads it on its own virtual thread. The main thread only blocks
 * on standard input so that the process stays alive.
 */
public class ServerApplication {
    private static final Logger log = LogManager.getLogger(ServerApplication.class);

    public static void main(String[] args) throws Exception {
        String ip = "0.0.0.0";
        int port = 8080;
        log.info("Running on ip {} port {}", ip, port);

        Thread.ofVirtual().name("tcp-listener").start(() -> {
            AtomicInteger nextId = new AtomicInteger();

            try (ServerSocket serverSocket = new ServerSocket()) {
                serverSocket.setReuseAddress(true);
                // Default proposed SO_RCVBUF for sockets accepted from this ServerSocket.
                serverSocket.setReceiveBufferSize(1024 * 10);
                serverSocket.bind(new InetSocketAddress(ip, port), 100); // backlog = 100

                // Stop once the listening socket is closed; otherwise every accept() would fail
                // immediately and this loop would spin forever logging errors.
                while (!serverSocket.isClosed()) {
                    acceptOne(serverSocket, nextId);
                }
            } catch (Exception e) {
                log.error("", e);
            }
        });

        // Virtual threads are daemon threads, so returning from main would end the process.
        // Block until a byte (or EOF, which returns -1 immediately) arrives on standard input.
        try {
            System.in.read();
        } catch (java.io.IOException e) {
            log.warn("Failed to read standard input; shutting down", e);
        }
    }

    /**
     * Accepts a single connection, configures it and hands it to a new {@link SocketClient}.
     * If anything fails before the hand-off, the accepted socket is closed so it cannot leak.
     */
    private static void acceptOne(ServerSocket serverSocket, AtomicInteger nextId) {
        Socket socket = null;
        boolean handedOff = false;
        try {
            socket = serverSocket.accept();
            socket.setSendBufferSize(1024 * 10);
            socket.setKeepAlive(false);
            socket.setTcpNoDelay(true);

            int socketId = nextId.incrementAndGet();
            SocketClient client = new SocketClient(socketId);
            if (client.register(socket)) {
                // From here on the client (and its reader thread) owns the socket.
                handedOff = true;
                client.inActive();
                log.info("[{}] connect success", socketId);
            }
        } catch (Exception e) {
            log.error("", e);
        } finally {
            if (!handedOff) {
                closeQuietly(socket);
            }
        }
    }

    /** Closes {@code socket} if non-null, logging (not propagating) any close failure. */
    private static void closeQuietly(Socket socket) {
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (java.io.IOException e) {
            log.debug("socket close failed", e);
        }
    }
}

 

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.net.Socket;
import java.util.concurrent.locks.ReentrantLock;

/**
 * One accepted TCP connection. It owns the socket, its buffered streams, and the
 * {@link ClientInbound} reader that runs on a dedicated virtual thread.
 *
 * <p>{@link #close()} may be invoked from the reader thread or from any other thread. It is
 * guarded by {@code mutex} and performs the teardown at most once.
 */
public final class SocketClient {
    private static final Logger log = LogManager.getLogger(SocketClient.class);

    private static final int STREAM_BUFFER_SIZE = 1024 * 8;

    private final Integer socketId;
    // socket, in, out, clientInbound, open and closed are guarded by mutex.
    private Socket socket;
    private BufferedInputStream in;
    private BufferedOutputStream out;
    private ClientInbound clientInbound;
    // volatile: written by the connection's I/O code, readable from any thread via the getters.
    private volatile long lastSendTime;
    private volatile long lastReadTime;
    private final ReentrantLock mutex = new ReentrantLock(true);
    // open: teardown still pending. closed: teardown already done; the client is never reopened.
    private boolean open;
    private boolean closed;

    public SocketClient(Integer socketId) {
        this.socketId = socketId;
    }

    /**
     * Wraps the socket's streams, marks this client open, and starts its reader virtual thread.
     *
     * @param socket a connected socket; it is closed by this method if registration fails
     * @return {@code true} if the reader thread was started, {@code false} if the socket's
     *     streams could not be obtained
     */
    public boolean register(Socket socket) {
        BufferedInputStream input;
        BufferedOutputStream output;
        try {
            input = new BufferedInputStream(socket.getInputStream(), STREAM_BUFFER_SIZE);
            output = new BufferedOutputStream(socket.getOutputStream(), STREAM_BUFFER_SIZE);
        } catch (IOException e) {
            log.warn("[{}] register failed", socketId, e);
            closeQuietly(socket);
            return false;
        }

        ClientInbound inbound = new ClientInbound(this, input, output);

        mutex.lock();
        try {
            this.socket = socket;
            this.in = input;
            this.out = output;
            this.clientInbound = inbound;
            // Mark open BEFORE the reader starts. If the peer disconnects immediately, the
            // reader's close() must tear the socket down instead of returning early.
            this.open = true;
        } finally {
            mutex.unlock();
        }

        Thread.ofVirtual().name("client-inbound-" + socketId).start(inbound);
        return true;
    }

    /**
     * Marks the client as open. {@link #register(Socket)} already does this. The call is a
     * no-op once the client has been closed, so a closed client is never reopened.
     */
    public void inActive() {
        mutex.lock();
        try {
            if (!closed) {
                this.open = true;
            }
        } finally {
            mutex.unlock();
        }
    }

    public void setLastSendTime() {
        this.lastSendTime = System.currentTimeMillis();
    }

    public void setLastReadTime() {
        this.lastReadTime = System.currentTimeMillis();
    }

    public long getLastSendTime() {
        return lastSendTime;
    }

    public long getLastReadTime() {
        return lastReadTime;
    }

    /**
     * Closes the socket (which also closes its streams) and the reader's input stream.
     * Only the first call on an open client does anything; later calls return immediately.
     */
    public void close() {
        mutex.lock();
        try {
            if (!open) {
                return;
            }
            open = false;
            closed = true;

            closeQuietly(socket); // Socket close, InputStream close, OutputStream close

            if (clientInbound != null) {
                clientInbound.close();
            }

            socket = null;
            clientInbound = null;
            log.info("[{}] close success", socketId);
        } finally {
            mutex.unlock();
        }
    }

    /** Closes {@code target} if non-null, ignoring close failures (teardown is best-effort). */
    private static void closeQuietly(Socket target) {
        if (target == null) {
            return;
        }
        try {
            target.close();
        } catch (IOException ignore) {
            // best-effort close; nothing useful to do on failure
        }
    }
}

 

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;

/**
 * Reads framed packets from one client connection and echoes POST packets back.
 *
 * <p>Frame layout, as read by {@link #run()}:
 * <ul>
 *   <li>a 1-byte message type;</li>
 *   <li>a 4-byte big-endian body length;</li>
 *   <li>{@code bodyLength} body bytes.</li>
 * </ul>
 *
 * <p>PUSH frames are consumed without a reply. POST_SYNC and POST_ASYNC frames are echoed back
 * with the same type and body. An unknown type, an out-of-range length, EOF, or an I/O error
 * ends the loop, and the owning {@link SocketClient} is then always closed.
 */
public final class ClientInbound implements Runnable {
    private static final Logger log = LogManager.getLogger(ClientInbound.class);

    private static final int EOF = -1;

    private final SocketClient socketClient;
    private final BufferedInputStream in;
    private final BufferedOutputStream out;

    public ClientInbound(SocketClient socketClient, BufferedInputStream in, BufferedOutputStream out) {
        this.socketClient = socketClient;
        this.in = in;
        this.out = out;
    }

    @Override
    public void run() {
        byte[] header = new byte[PacketConst.HEAD_SIZE];
        try {
            while (true) {
                // TCP is a byte stream: one read() may return part of a frame or several frames.
                // Read the header and then the body to completion.
                if (!readFully(header)) {
                    log.info("Socket closed 0");
                    break; // loop exit
                }

                byte messageType = header[0];
                if (!isKnownType(messageType)) {
                    log.warn("unknown message type {}", messageType);
                    break; // loop exit
                }

                int bodyLength = toInt(header);
                // Reject negative lengths too; they would otherwise reach new byte[] / allocate().
                if (bodyLength < 0 || bodyLength > PacketConst.BUFFER_SIZE_DEFAULT) {
                    log.warn("invalid body length {}", bodyLength);
                    break; // loop exit
                }

                byte[] body = new byte[bodyLength];
                if (!readFully(body)) {
                    log.info("Socket closed 0");
                    break; // loop exit
                }

                socketClient.setLastReadTime();

                // Packet received successfully: echo POST messages back unchanged.
                if (messageType == PacketConst.Type.POST_SYNC || messageType == PacketConst.Type.POST_ASYNC) {
                    send(encode(messageType, body));
                }
            }
        } catch (SocketException e) {
            log.info(e.getMessage());
        } catch (IOException e) {
            if ("Socket closed".equals(e.getMessage())) {
                log.info("Socket closed");
            } else {
                log.error("", e);
            }
        } catch (RuntimeException e) {
            log.error("", e);
        } finally {
            // Always release the connection, even if the loop died on an unexpected exception.
            socketClient.close();
        }
    }

    /**
     * Writes the buffer's remaining bytes to the client and flushes.
     * The buffer's position is left unchanged. I/O errors are logged and not rethrown.
     */
    public void send(ByteBuffer buffer) {
        try {
            if (buffer.hasArray()) {
                out.write(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining());
            } else {
                // Direct or read-only buffers expose no backing array; copy without moving position.
                byte[] copy = new byte[buffer.remaining()];
                buffer.duplicate().get(copy);
                out.write(copy);
            }
            out.flush();
        } catch (IOException e) {
            log.error("", e);
            return;
        }

        socketClient.setLastSendTime();
    }

    /** Closes the input stream, ignoring failures. */
    public void close() {
        try {
            in.close();
        } catch (IOException ignore) {

        }
    }

    /**
     * Fills {@code dst} completely from the input stream.
     *
     * @return {@code false} if the stream reached EOF before {@code dst} was filled
     */
    private boolean readFully(byte[] dst) throws IOException {
        int filled = 0;
        while (filled < dst.length) {
            int n;
            try {
                n = in.read(dst, filled, dst.length - filled);
            } catch (SocketTimeoutException ignore) {
                // Only thrown if an SO_TIMEOUT is set on the socket. Keep waiting and keep the
                // bytes already read for this frame.
                continue;
            }
            if (n == EOF) {
                return false;
            }
            filled += n;
        }
        return true;
    }

    /** Returns whether {@code type} is one of the message types defined in {@link PacketConst.Type}. */
    private static boolean isKnownType(byte type) {
        return type == PacketConst.Type.PUSH
                || type == PacketConst.Type.POST_SYNC
                || type == PacketConst.Type.POST_ASYNC;
    }

    /**
     * Builds a frame (type byte, big-endian body length, body) ready for {@link #send}:
     * position 0, limit = frame length.
     */
    private static ByteBuffer encode(byte messageType, byte[] body) {
        ByteBuffer buffer = ByteBuffer.allocate(PacketConst.HEAD_SIZE + body.length);
        buffer.put(messageType);
        buffer.putInt(body.length); // ByteBuffer is big-endian by default, matching toInt
        buffer.put(body);
        buffer.flip();
        return buffer;
    }

    /** Decodes {@code v[1..4]} as a big-endian int (the body-length field after the type byte). */
    private int toInt(byte[] v) {
        return (v[1] & 0xFF) << 24 | (v[2] & 0xFF) << 16 | (v[3] & 0xFF) << 8 | v[4] & 0xFF;
    }
}

 

/**
 * Wire-protocol constants for the TCP server.
 *
 * <p>A frame starts with a header of {@link #HEAD_SIZE} bytes, made up of a message-type field
 * and a body-length field. The body follows the header.
 */
public interface PacketConst {
    /** Width of the message-type field, in bytes. */
    int MSG_ID_SIZE = 1;

    /** Width of the body-length field, in bytes. */
    int BODY_LENGTH_SIZE = 4;

    /** Total header width: message-type field followed by body-length field. */
    int HEAD_SIZE = MSG_ID_SIZE + BODY_LENGTH_SIZE;

    /** Default buffer size, in bytes (2 KiB). */
    int BUFFER_SIZE_DEFAULT = 2 * 1024;

    /** Presumably the width of a sender-id field in POST_SYNC messages — TODO confirm usage. */
    int POST_SYNC_MESSAGE_FROM_ID_SIZE = 4;

    /** Message-type codes carried in the first header byte. */
    interface Type {
        byte PUSH = 0x07;
        byte POST_SYNC = 0x08;
        byte POST_ASYNC = 0x09;
    }
}