/*
 * Decompiled with CFR 0.152.
 */
package org.apache.zookeeper.server;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.ChannelGroupFuture;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.OptionalSslHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.DefaultEventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.X509KeyManager;
import javax.net.ssl.X509TrustManager;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.common.ClientX509Util;
import org.apache.zookeeper.common.NettyUtils;
import org.apache.zookeeper.common.SSLContextAndOptions;
import org.apache.zookeeper.common.X509Exception;
import org.apache.zookeeper.server.NettyServerCnxn;
import org.apache.zookeeper.server.ServerCnxn;
import org.apache.zookeeper.server.ServerCnxnFactory;
import org.apache.zookeeper.server.ZooKeeperServer;
import org.apache.zookeeper.server.auth.ProviderRegistry;
import org.apache.zookeeper.server.auth.X509AuthenticationProvider;
import org.apache.zookeeper.server.quorum.QuorumPeerConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NettyServerCnxnFactory
extends ServerCnxnFactory {
    private static final Logger LOG = LoggerFactory.getLogger(NettyServerCnxnFactory.class);

    /** Name of the boolean system property that turns on client port unification. */
    public static final String PORT_UNIFICATION_KEY = "zookeeper.client.portUnification";

    /**
     * Effective port-unification setting, decided once in the constructor. When true and the
     * factory is not configured as secure, each channel gets a {@link DualModeSslHandler}.
     */
    private final boolean shouldUsePortUnification;

    /** First byte expected from a client that opens the connection with a TLS handshake record. */
    private static final byte TLS_HANDSHAKE_RECORD_TYPE = 22;

    private final ServerBootstrap bootstrap;

    /** Listening channel; stays null until {@link #start()} or {@link #reconfigure} binds one. */
    private Channel parentChannel;

    /** Channels registered by this factory; closed as a group in {@link #shutdown()}. */
    private final ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns", new DefaultEventExecutor());

    /** Connections grouped by remote address. Every access synchronizes on the map itself. */
    private final Map<InetAddress, Set<NettyServerCnxn>> ipMap = new HashMap<InetAddress, Set<NettyServerCnxn>>();

    private InetSocketAddress localAddress;

    /** Value reported by {@link #getMaxClientCnxnsPerHost()}; defaults to 60. */
    private int maxClientCnxns = 60;

    private final ClientX509Util x509Util;

    /** Channel attribute under which the {@link NettyServerCnxn} of a channel is stored. */
    private static final AttributeKey<NettyServerCnxn> CONNECTION_ATTRIBUTE = AttributeKey.valueOf("NettyServerCnxn");

    /**
     * Optional allocator override consulted when a factory is constructed.
     * Declared with the diamond so the reference type matches the field type; an
     * {@code AtomicReference<Object>} is not assignable to {@code AtomicReference<ByteBufAllocator>}.
     */
    private static final AtomicReference<ByteBufAllocator> TEST_ALLOCATOR = new AtomicReference<>(null);

    /** Single handler instance added to every child channel pipeline. */
    CnxnChannelHandler channelHandler = new CnxnChannelHandler();

    /** Set by {@link #shutdown()}; read and written under this object's monitor. */
    private boolean killed;

    /**
     * Applies the allocator held in {@code TEST_ALLOCATOR}, if any, to both the server channel
     * and its child channels.
     *
     * @param bootstrap the bootstrap to configure
     * @return the bootstrap returned by the option calls, or {@code bootstrap} itself when no
     *         override is installed
     */
    private ServerBootstrap configureBootstrapAllocator(ServerBootstrap bootstrap) {
        final ByteBufAllocator override = TEST_ALLOCATOR.get();
        if (override == null) {
            return bootstrap;
        }
        return bootstrap
                .option(ChannelOption.ALLOCATOR, override)
                .childOption(ChannelOption.ALLOCATOR, override);
    }

    /**
     * Builds the Netty server bootstrap without binding it.
     *
     * <p>Port unification is read from the {@link #PORT_UNIFICATION_KEY} system property and is
     * switched off again if {@code QuorumPeerConfig.configureSSLAuth()} fails. Every accepted
     * channel gets an SSL handler (strict when the factory is secure, dual-mode when only port
     * unification is on) followed by the shared {@link CnxnChannelHandler}.
     */
    NettyServerCnxnFactory() {
        x509Util = new ClientX509Util();

        boolean portUnification = Boolean.getBoolean(PORT_UNIFICATION_KEY);
        LOG.info("{}={}", PORT_UNIFICATION_KEY, portUnification);
        if (portUnification) {
            try {
                QuorumPeerConfig.configureSSLAuth();
            } catch (QuorumPeerConfig.ConfigException e) {
                LOG.error("unable to set up SslAuthProvider, turning off client port unification", e);
                portUnification = false;
            }
        }
        shouldUsePortUnification = portUnification;

        final EventLoopGroup acceptorGroup =
                NettyUtils.newNioOrEpollEventLoopGroup(NettyUtils.getClientReachableLocalInetAddressCount());
        final EventLoopGroup ioGroup = NettyUtils.newNioOrEpollEventLoopGroup();

        final ServerBootstrap unconfigured = new ServerBootstrap()
                .group(acceptorGroup, ioGroup)
                .channel(NettyUtils.nioOrEpollServerSocketChannel())
                .option(ChannelOption.SO_REUSEADDR, true)
                .childOption(ChannelOption.TCP_NODELAY, true)
                .childOption(ChannelOption.SO_LINGER, -1)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) throws Exception {
                        final ChannelPipeline pipeline = ch.pipeline();
                        // The secure flag is read per channel, so it reflects configure() calls
                        // made after construction.
                        if (secure) {
                            initSSL(pipeline, false);
                        } else if (shouldUsePortUnification) {
                            initSSL(pipeline, true);
                        }
                        pipeline.addLast("servercnxnfactory", channelHandler);
                    }
                });

        bootstrap = configureBootstrapAllocator(unconfigured);
        bootstrap.validate();
    }

    /**
     * Appends an SSL handler named {@code "ssl"} to the given pipeline.
     *
     * <p>If the SSL auth provider system property is unset, the default context from
     * {@code x509Util} is used. Otherwise a TLSv1.2 {@link SSLContext} is initialised with the
     * key and trust managers of the configured {@link X509AuthenticationProvider}.
     *
     * @param p2 pipeline of the channel being initialised
     * @param supportPlaintext when true a {@link DualModeSslHandler} is installed, otherwise a
     *        plain Netty SSL handler
     * @throws X509Exception if the configured auth provider cannot be found, or if creating the
     *         default context fails
     * @throws KeyManagementException if initialising the SSL context fails
     * @throws NoSuchAlgorithmException if TLSv1.2 is not available
     */
    private synchronized void initSSL(ChannelPipeline p2, boolean supportPlaintext) throws X509Exception, KeyManagementException, NoSuchAlgorithmException {
        final String configuredProvider = System.getProperty(x509Util.getSslAuthProviderProperty());
        final SslContext nettyContext;
        if (configuredProvider != null) {
            final SSLContext jdkContext = SSLContext.getInstance("TLSv1.2");
            final String providerName = System.getProperty(x509Util.getSslAuthProviderProperty(), "x509");
            final X509AuthenticationProvider provider =
                    (X509AuthenticationProvider) ProviderRegistry.getProvider(providerName);
            if (provider == null) {
                LOG.error("Auth provider not found: {}", configuredProvider);
                throw new X509Exception.SSLContextException("Could not create SSLContext with specified auth provider: " + configuredProvider);
            }
            jdkContext.init(
                    new X509KeyManager[]{provider.getKeyManager()},
                    new X509TrustManager[]{provider.getTrustManager()},
                    null);
            nettyContext = x509Util.getDefaultSSLContextAndOptions().createNettyJdkSslContext(jdkContext, false);
        } else {
            final SSLContextAndOptions defaults = x509Util.getDefaultSSLContextAndOptions();
            nettyContext = defaults.createNettyJdkSslContext(defaults.getSSLContext(), false);
        }

        if (!supportPlaintext) {
            p2.addLast("ssl", nettyContext.newHandler(p2.channel().alloc()));
            LOG.debug("SSL handler added for channel: {}", p2.channel());
        } else {
            p2.addLast("ssl", new DualModeSslHandler(nettyContext));
            LOG.debug("dual mode SSL handler added for channel: {}", p2.channel());
        }
    }

    /**
     * Closes every connection currently tracked in {@code cnxns}. A failure to close one
     * connection is logged and does not stop the remaining ones from being closed.
     */
    @Override
    public void closeAll() {
        LOG.debug("closeAll()");

        // Captured before closing so the debug line reports the count at entry.
        final int trackedAtEntry = cnxns.size();
        for (ServerCnxn connection : cnxns) {
            try {
                connection.close();
            } catch (Exception e) {
                LOG.warn("Ignoring exception closing cnxn sessionid 0x" + Long.toHexString(connection.getSessionId()), e);
            }
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("allChannels size:" + allChannels.size() + " cnxns size:" + trackedAtEntry);
        }
    }

    /**
     * Closes the first tracked connection whose session id equals {@code sessionId}.
     *
     * @param sessionId id of the session to close
     * @return true if a matching connection was found (even if closing it threw), false otherwise
     */
    @Override
    public boolean closeSession(long sessionId) {
        if (LOG.isDebugEnabled()) {
            // The message carries a 0x prefix, so the id must be rendered in hexadecimal;
            // previously the decimal value was appended after the prefix.
            LOG.debug("closeSession sessionid:0x{}", Long.toHexString(sessionId));
        }
        for (ServerCnxn cnxn : this.cnxns) {
            if (cnxn.getSessionId() != sessionId) {
                continue;
            }
            try {
                cnxn.close();
            } catch (Exception e) {
                LOG.warn("exception during session close", e);
            }
            return true;
        }
        return false;
    }

    /**
     * Records the bind address, connection limit and secure flag for later use by
     * {@link #start()}. SASL login is configured first; if that throws, nothing is recorded.
     *
     * @param addr2 address the factory will bind to
     * @param maxClientCnxns value later reported by {@link #getMaxClientCnxnsPerHost()}
     * @param secure whether accepted channels must use SSL
     * @throws IOException propagated from {@code configureSaslLogin()}
     */
    @Override
    public void configure(InetSocketAddress addr2, int maxClientCnxns, boolean secure) throws IOException {
        configureSaslLogin();

        localAddress = addr2;
        this.maxClientCnxns = maxClientCnxns;
        this.secure = secure;
    }

    /** Returns the connection limit last set through {@code configure} or the setter (60 by default). */
    @Override
    public int getMaxClientCnxnsPerHost() {
        return maxClientCnxns;
    }

    /**
     * Replaces the stored connection limit.
     *
     * @param max2 new value to report from {@link #getMaxClientCnxnsPerHost()}
     */
    @Override
    public void setMaxClientCnxnsPerHost(int max2) {
        maxClientCnxns = max2;
    }

    /**
     * Returns the port of the current local address. After {@link #start()} this is the port of
     * the bound channel rather than the configured one.
     */
    @Override
    public int getLocalPort() {
        return localAddress.getPort();
    }

    /**
     * Blocks the calling thread until {@link #shutdown()} has marked this factory as killed.
     *
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    @Override
    public void join() throws InterruptedException {
        synchronized (this) {
            // Loop guards against wake-ups that happen before the flag is set.
            while (!killed) {
                wait();
            }
        }
    }

    /**
     * Shuts the factory down; returns early if it was already marked as killed.
     *
     * <p>Steps, in order: close {@code x509Util}, shut down the SASL login if present, close the
     * listening channel and all client connections (each event loop group is shut down once the
     * corresponding close completes), shut down the ZooKeeper server if one is attached, then
     * set the killed flag and wake threads blocked in {@link #join()}.
     */
    @Override
    public void shutdown() {
        synchronized (this) {
            if (killed) {
                LOG.info("already shutdown {}", localAddress);
                return;
            }
        }
        LOG.info("shutdown called {}", localAddress);

        x509Util.close();
        if (login != null) {
            login.shutdown();
        }

        final EventLoopGroup acceptorGroup = bootstrap.config().group();
        final EventLoopGroup ioGroup = bootstrap.config().childGroup();

        if (parentChannel == null) {
            // Nothing was ever bound, so there are no close futures to wait on.
            if (acceptorGroup != null) {
                acceptorGroup.shutdownGracefully();
            }
            if (ioGroup != null) {
                ioGroup.shutdownGracefully();
            }
        } else {
            final ChannelFuture listenerClosed = parentChannel.close();
            if (acceptorGroup != null) {
                listenerClosed.addListener(done -> acceptorGroup.shutdownGracefully());
            }

            closeAll();
            final ChannelGroupFuture clientsClosed = allChannels.close();
            if (ioGroup != null) {
                clientsClosed.addListener(done -> ioGroup.shutdownGracefully());
            }
        }

        if (zkServer != null) {
            zkServer.shutdown();
        }

        synchronized (this) {
            killed = true;
            notifyAll();
        }
    }

    /**
     * Binds the server bootstrap to the configured address, waiting uninterruptibly for the bind
     * to finish, and then replaces {@code localAddress} with the address the channel reports.
     */
    @Override
    public void start() {
        LOG.info("binding to port {}", localAddress);

        final ChannelFuture bindFuture = bootstrap.bind(localAddress);
        parentChannel = bindFuture.syncUninterruptibly().channel();

        // Take the address from the bound channel instead of keeping the requested one.
        localAddress = (InetSocketAddress) parentChannel.localAddress();
        LOG.info("bound to port " + getLocalPort());
    }

    /**
     * Rebinds the factory to {@code addr2} and closes the previously bound listening channel.
     *
     * <p>The old channel is closed whether or not the new bind succeeds. A bind failure
     * propagates to the caller after the old channel has been asked to close.
     *
     * @param addr2 the new address to listen on
     */
    @Override
    public void reconfigure(InetSocketAddress addr2) {
        final Channel oldChannel = this.parentChannel;
        try {
            LOG.info("binding to port {}", addr2);
            this.parentChannel = this.bootstrap.bind(addr2).syncUninterruptibly().channel();
            this.localAddress = (InetSocketAddress) this.parentChannel.localAddress();
            LOG.info("bound to port " + this.getLocalPort());
        } finally {
            // Netty's Channel is not AutoCloseable, so it cannot be a try-with-resources
            // resource; close it explicitly. Null means the factory was never started.
            if (oldChannel != null) {
                oldChannel.close();
            }
        }
    }

    /**
     * Binds the listening port, attaches the given server to this factory and optionally starts
     * the server.
     *
     * @param zks server to attach
     * @param startServer when true, {@code zks.startdata()} and {@code zks.startup()} are called
     * @throws IOException propagated from the server start calls
     * @throws InterruptedException propagated from the server start calls
     */
    @Override
    public void startup(ZooKeeperServer zks, boolean startServer) throws IOException, InterruptedException {
        start();
        setZooKeeperServer(zks);
        if (!startServer) {
            return;
        }
        zks.startdata();
        zks.startup();
    }

    /** Returns the live collection of tracked connections; no copy is made. */
    @Override
    public Iterable<ServerCnxn> getConnections() {
        return cnxns;
    }

    /** Returns the configured address, or the bound channel's address once the factory has started. */
    @Override
    public InetSocketAddress getLocalAddress() {
        return localAddress;
    }

    /**
     * Starts tracking a connection: adds it to {@code cnxns} and to the set kept for its remote
     * address in {@code ipMap}.
     *
     * @param cnxn connection to track; its channel's remote address is used as the map key
     */
    private void addCnxn(NettyServerCnxn cnxn) {
        cnxns.add(cnxn);
        synchronized (ipMap) {
            final InetSocketAddress remote = (InetSocketAddress) cnxn.getChannel().remoteAddress();
            ipMap.computeIfAbsent(remote.getAddress(), key -> new HashSet<NettyServerCnxn>()).add(cnxn);
        }
    }

    /**
     * Removes a connection from the per-address bookkeeping and drops the address entry once its
     * set becomes empty. Logs an error if no set exists for the address.
     *
     * @param cnxn connection to remove
     * @param remoteAddress address under which the connection was registered
     */
    void removeCnxnFromIpMap(NettyServerCnxn cnxn, InetAddress remoteAddress) {
        final boolean addressKnown;
        synchronized (ipMap) {
            final Set<NettyServerCnxn> peers = ipMap.get(remoteAddress);
            addressKnown = peers != null;
            if (addressKnown) {
                peers.remove(cnxn);
                if (peers.isEmpty()) {
                    ipMap.remove(remoteAddress);
                }
            }
        }
        // Logged after the lock is released.
        if (!addressKnown) {
            LOG.error("Unexpected null set for remote address {} when removing cnxn {}", remoteAddress, cnxn);
        }
    }

    /** Calls {@code resetStats()} on every tracked connection. */
    @Override
    public void resetAllConnectionStats() {
        cnxns.forEach(ServerCnxn::resetStats);
    }

    /**
     * Collects the connection-info map of every tracked connection into a new set.
     *
     * @param brief passed through to {@code ServerCnxn.getConnectionInfo}
     * @return a freshly built set holding one info map per connection (equal maps collapse)
     */
    @Override
    public Iterable<Map<String, Object>> getAllConnectionInfo(boolean brief) {
        final Set<Map<String, Object>> snapshot = new HashSet<>();
        cnxns.forEach(connection -> snapshot.add(connection.getConnectionInfo(brief)));
        return snapshot;
    }

    /**
     * Installs an allocator override. Factories constructed afterwards apply it to their server
     * and child channels; already constructed factories are not affected by this call.
     * Presumably intended for tests only, judging by the name.
     *
     * @param allocator allocator to use for subsequently constructed factories
     */
    static void setTestAllocator(ByteBufAllocator allocator) {
        TEST_ALLOCATOR.set(allocator);
    }

    /** Removes the allocator override so later factories keep the bootstrap's default allocator. */
    static void clearTestAllocator() {
        setTestAllocator(null);
    }

    final class CertificateVerifier
    implements GenericFutureListener<Future<Channel>> {
        private final SslHandler sslHandler;
        private final NettyServerCnxn cnxn;

        CertificateVerifier(SslHandler sslHandler, NettyServerCnxn cnxn) {
            this.sslHandler = sslHandler;
            this.cnxn = cnxn;
        }

        @Override
        public void operationComplete(Future<Channel> future) {
            if (future.isSuccess()) {
                SSLEngine eng;
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Successful handshake with session 0x{}", (Object)Long.toHexString(this.cnxn.getSessionId()));
                }
                if ((eng = this.sslHandler.engine()).getNeedClientAuth() || eng.getWantClientAuth()) {
                    SSLSession session = eng.getSession();
                    try {
                        this.cnxn.setClientCertificateChain(session.getPeerCertificates());
                    }
                    catch (SSLPeerUnverifiedException e) {
                        if (eng.getNeedClientAuth()) {
                            LOG.error("Error getting peer certificates", e);
                            this.cnxn.close();
                            return;
                        }
                        Channel futureChannel = future.getNow();
                        NettyServerCnxnFactory.this.allChannels.add(Objects.requireNonNull(futureChannel));
                        NettyServerCnxnFactory.this.addCnxn(this.cnxn);
                        return;
                    }
                    catch (Exception e) {
                        LOG.error("Error getting peer certificates", e);
                        this.cnxn.close();
                        return;
                    }
                    String authProviderProp = System.getProperty(NettyServerCnxnFactory.this.x509Util.getSslAuthProviderProperty(), "x509");
                    X509AuthenticationProvider authProvider = (X509AuthenticationProvider)ProviderRegistry.getProvider(authProviderProp);
                    if (authProvider == null) {
                        LOG.error("X509 Auth provider not found: {}", (Object)authProviderProp);
                        this.cnxn.close();
                        return;
                    }
                    if (KeeperException.Code.OK != authProvider.handleAuthentication(this.cnxn, null)) {
                        LOG.error("Authentication failed for session 0x{}", (Object)Long.toHexString(this.cnxn.getSessionId()));
                        this.cnxn.close();
                        return;
                    }
                }
                Channel futureChannel = future.getNow();
                NettyServerCnxnFactory.this.allChannels.add(Objects.requireNonNull(futureChannel));
                NettyServerCnxnFactory.this.addCnxn(this.cnxn);
            } else {
                LOG.error("Unsuccessful handshake with session 0x{}", (Object)Long.toHexString(this.cnxn.getSessionId()));
                this.cnxn.close();
            }
        }
    }

    /**
     * Shared pipeline handler that ties each channel to a {@link NettyServerCnxn}, stored under
     * {@code CONNECTION_ATTRIBUTE}, and forwards channel events to it.
     */
    @ChannelHandler.Sharable
    class CnxnChannelHandler
    extends ChannelDuplexHandler {
        /** Trace-level listener reporting whether a write completed or failed. */
        private final GenericFutureListener<Future<Void>> onWriteCompletedTracer =
                result -> LOG.trace("write {}", result.isSuccess() ? "complete" : "failed");

        CnxnChannelHandler() {
        }

        /**
         * Creates the connection object for a new channel. A secure channel is registered by
         * {@link CertificateVerifier} after the handshake. A plain channel is registered here,
         * unless port unification is on, in which case registration is left to the SSL handler
         * added earlier in the pipeline.
         */
        @Override
        public void channelActive(ChannelHandlerContext ctx) throws Exception {
            final Channel channel = ctx.channel();
            if (LOG.isTraceEnabled()) {
                LOG.trace("Channel active {}", channel);
            }

            final NettyServerCnxn cnxn = new NettyServerCnxn(channel, zkServer, NettyServerCnxnFactory.this);
            channel.attr(CONNECTION_ATTRIBUTE).set(cnxn);

            if (secure) {
                final SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
                sslHandler.handshakeFuture().addListener(new CertificateVerifier(sslHandler, cnxn));
                return;
            }
            if (!shouldUsePortUnification) {
                allChannels.add(channel);
                addCnxn(cnxn);
            }
        }

        /** Stops tracking the channel and closes its connection, if one is still attached. */
        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            final Channel channel = ctx.channel();
            if (LOG.isTraceEnabled()) {
                LOG.trace("Channel inactive {}", channel);
            }
            allChannels.remove(channel);

            // Detach atomically so the connection is closed by at most one callback.
            final NettyServerCnxn cnxn = channel.attr(CONNECTION_ATTRIBUTE).getAndSet(null);
            if (cnxn == null) {
                return;
            }
            if (LOG.isTraceEnabled()) {
                LOG.trace("Channel inactive caused close {}", cnxn);
            }
            cnxn.close();
        }

        /** Logs the failure and closes the connection attached to the channel, if any. */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause2) throws Exception {
            LOG.warn("Exception caught", cause2);

            final NettyServerCnxn cnxn = ctx.channel().attr(CONNECTION_ATTRIBUTE).getAndSet(null);
            if (cnxn == null) {
                return;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("Closing {}", cnxn);
            }
            cnxn.close();
        }

        /**
         * Reacts to auto-read events: ENABLE drains the connection's queued buffer and then
         * turns auto-read on; DISABLE turns it off. The event is always released.
         */
        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
            try {
                final Channel channel = ctx.channel();
                if (evt == NettyServerCnxn.AutoReadEvent.ENABLE) {
                    LOG.debug("Received AutoReadEvent.ENABLE");
                    final NettyServerCnxn cnxn = channel.attr(CONNECTION_ATTRIBUTE).get();
                    if (cnxn != null) {
                        cnxn.processQueuedBuffer();
                    }
                    channel.config().setAutoRead(true);
                } else if (evt == NettyServerCnxn.AutoReadEvent.DISABLE) {
                    LOG.debug("Received AutoReadEvent.DISABLE");
                    channel.config().setAutoRead(false);
                }
            } finally {
                ReferenceCountUtil.release(evt);
            }
        }

        /**
         * Hands an inbound buffer to the channel's connection. Exceptions are logged and
         * rethrown; the message is released in every case.
         */
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            try {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("message received called {}", msg);
                }
                try {
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("New message {} from {}", msg, ctx.channel());
                    }
                    final NettyServerCnxn cnxn = ctx.channel().attr(CONNECTION_ATTRIBUTE).get();
                    if (cnxn != null) {
                        cnxn.processMessage((ByteBuf) msg);
                    } else {
                        LOG.error("channelRead() on a closed or closing NettyServerCnxn");
                    }
                } catch (Exception ex) {
                    LOG.error("Unexpected exception in receive", ex);
                    throw ex;
                }
            } finally {
                ReferenceCountUtil.release(msg);
            }
        }

        /** Forwards the write unchanged, attaching the tracing listener when trace logging is on. */
        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            if (LOG.isTraceEnabled()) {
                promise.addListener(onWriteCompletedTracer);
            }
            super.write(ctx, msg, promise);
        }
    }

    /**
     * SSL handler used for port unification: it lets a connection continue either as TLS or as
     * plaintext and registers the connection with the factory once the mode is known.
     */
    class DualModeSslHandler
    extends OptionalSslHandler {
        /**
         * @param sslContext context used to build the real SSL handler for TLS clients
         */
        DualModeSslHandler(SslContext sslContext) {
            super(sslContext);
        }

        /**
         * Delegates to the superclass once at least 5 bytes are readable. With fewer bytes it
         * falls back to plaintext as soon as the first readable byte is not a TLS handshake
         * record type, so short plaintext requests are not left waiting for more data.
         */
        @Override
        protected void decode(ChannelHandlerContext context, ByteBuf in, List<Object> out) throws Exception {
            final int readable = in.readableBytes();
            if (readable >= 5) {
                super.decode(context, in, out);
                return;
            }
            if (readable == 0) {
                return;
            }
            // getByte takes an absolute index, so read at readerIndex() rather than 0 to
            // inspect the first unread byte even when the reader index is not zero.
            final byte firstByte = in.getByte(in.readerIndex());
            if (firstByte != TLS_HANDSHAKE_RECORD_TYPE) {
                LOG.debug("first byte {} does not match TLS handshake, failing to plaintext", Byte.valueOf(firstByte));
                this.handleNonSsl(context);
            }
        }

        /** Swaps this handler for the plaintext handler, or removes it when there is none. */
        private void handleNonSsl(ChannelHandlerContext context) {
            ChannelHandler handler = this.newNonSslHandler(context);
            if (handler != null) {
                context.pipeline().replace(this, this.newNonSslHandlerName(), handler);
            } else {
                context.pipeline().remove(this);
            }
        }

        /**
         * Creates the SSL handler for a TLS client and attaches a {@link CertificateVerifier} to
         * its handshake future, which registers the connection after the handshake.
         */
        @Override
        protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) {
            NettyServerCnxn cnxn = Objects.requireNonNull(context.channel().attr(CONNECTION_ATTRIBUTE).get());
            LOG.debug("creating ssl handler for session {}", cnxn.getSessionId());
            SslHandler handler = super.newSslHandler(context, sslContext);
            Future<Channel> handshakeFuture = handler.handshakeFuture();
            handshakeFuture.addListener(new CertificateVerifier(handler, cnxn));
            return handler;
        }

        /**
         * Registers the connection as a plaintext one straight away, since no handshake will
         * follow, and then returns the superclass's plaintext handler.
         */
        @Override
        protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
            NettyServerCnxn cnxn = Objects.requireNonNull(context.channel().attr(CONNECTION_ATTRIBUTE).get());
            LOG.debug("creating plaintext handler for session {}", cnxn.getSessionId());
            NettyServerCnxnFactory.this.allChannels.add(context.channel());
            NettyServerCnxnFactory.this.addCnxn(cnxn);
            return super.newNonSslHandler(context);
        }
    }
}

