// RtmpHandshakeHandler.java (11.5 KB)
package com.genersoft.iot.vmp.jtt1078.rtmp;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
 * RTMP handshake and publish handler (Netty version).
 * <p>
 * Drives the client side of an RTMP publish session on one channel: the simple (non-digest)
 * handshake, then {@code connect}, {@code createStream} and {@code publish}. Once the publish
 * response has been seen, outbound {@link ByteBuf} messages are treated as single FLV tags and
 * re-framed as RTMP chunks. Before that, outbound {@link ByteBuf}s are dropped.
 * <p>
 * NOTE(review): server responses after the handshake are recognised by size heuristics only. They
 * are not parsed as RTMP chunks/AMF0, and the publish message stream id is assumed to be 1.
 */
public class RtmpHandshakeHandler extends ChannelDuplexHandler {

    private static final Logger logger = LoggerFactory.getLogger(RtmpHandshakeHandler.class);

    /** Outgoing chunk size, announced via SetChunkSize and used to split media messages. */
    private static final int RTMP_CHUNK_SIZE = 4096;
    /** Size of each of C1/C2/S1/S2 in the RTMP handshake. */
    private static final int HANDSHAKE_PACKET_SIZE = 1536;
    /** FLV tag header: type(1) + data size(3) + timestamp(3) + timestamp ext(1) + stream id(3). */
    private static final int FLV_TAG_HEADER_SIZE = 11;
    /** Largest 3-byte field value; as a timestamp it marks "extended timestamp follows". */
    private static final int MAX_MEDIUM = 0xFFFFFF;
    /** FLV tag type for video data. */
    private static final int FLV_TAG_TYPE_VIDEO = 9;

    private enum State {
        /** C0+C1 sent, waiting for S0+S1. */
        HANDSHAKE_C0C1,
        /** C2 sent, waiting for S2. */
        HANDSHAKE_C2,
        CONNECT, CREATE_STREAM, PUBLISH, STREAMING
    }

    /** Ready callback, run once when the state enters STREAMING. */
    private Runnable onReadyListener;

    public void setOnReadyListener(Runnable listener) {
        this.onReadyListener = listener;
    }

    private State state = State.HANDSHAKE_C0C1;
    private final String streamName;
    private final String rtmpUrl;
    private final String app;

    /** FLV timestamp of the first forwarded tag; RTMP timestamps are sent relative to it. */
    private long startTime = 0;
    private boolean isFirstTag = true;

    /** Collects S0/S1/S2 bytes, which may be split across several reads; null outside the handshake. */
    private ByteBuf handshakeBuffer;

    /**
     * @param app        RTMP application name, sent as {@code app} in connect
     * @param rtmpUrl    full RTMP URL, sent as {@code tcUrl}; its query string (from the first '?')
     *                   is appended to the publish name
     * @param streamName stream name used for publish and as the log prefix
     */
    public RtmpHandshakeHandler(String app, String rtmpUrl, String streamName) {
        this.app = app;
        this.rtmpUrl = rtmpUrl;
        this.streamName = streamName;
    }

    /** Starts the handshake by sending C0+C1 as soon as the TCP connection is up. */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        logger.info("[{}] TCP 连接建立成功,开始 RTMP 握手流程...", streamName);

        // C0 = version 3; C1 = 4-byte time followed by zeros (zero field + "random" bytes).
        ByteBuf c0c1 = Unpooled.buffer(1 + HANDSHAKE_PACKET_SIZE);
        c0c1.writeByte(0x03);
        c0c1.writeInt((int) (System.currentTimeMillis() / 1000));
        c0c1.writeZero(HANDSHAKE_PACKET_SIZE - 4);

        ctx.writeAndFlush(c0c1);
        logger.info("[{}] 已发送 C0+C1 握手包", streamName);

        // Keep propagating the event so handlers after this one also see channelActive.
        ctx.fireChannelActive();
    }

    /** Consumes inbound buffers (they are not forwarded); other message types pass through. */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof ByteBuf) {
            ByteBuf buf = (ByteBuf) msg;
            try {
                handleRtmpMessage(ctx, buf);
            } finally {
                buf.release();
            }
        } else {
            super.channelRead(ctx, msg);
        }
    }

    /** Advances the publish state machine for one inbound read. */
    private void handleRtmpMessage(ChannelHandlerContext ctx, ByteBuf msg) {
        switch (state) {
            case HANDSHAKE_C0C1:
            case HANDSHAKE_C2:
                handleHandshake(ctx, msg);
                break;

            case CONNECT:
                logger.info("[{}] 收到 connect 响应。", streamName);
                // Heuristic: any read larger than 20 bytes is taken as the connect result.
                if (msg.readableBytes() > 20) {
                    sendCreateStream(ctx);
                    state = State.CREATE_STREAM;
                    logger.info("[{}] >>> 状态流转: CONNECT -> CREATE_STREAM", streamName);
                }
                break;

            case CREATE_STREAM:
                logger.info("[{}] 收到 createStream 响应。", streamName);
                // Heuristic: any read larger than 10 bytes is taken as the createStream result.
                if (msg.readableBytes() > 10) {
                    sendPublish(ctx);
                    state = State.PUBLISH;
                    logger.info("[{}] >>> 状态流转: CREATE_STREAM -> PUBLISH (流名: {})", streamName, streamName);
                }
                break;

            case PUBLISH: {
                String response = safeReadAscii(msg);
                logger.info("[{}] 收到 publish 响应: {}", streamName, response);
                // NOTE(review): the size fallback accepts any read larger than 10 bytes, so an error
                // status would also be treated as success. Only the first 100 bytes are searched.
                if (response.contains("NetStream.Publish.Start") || msg.readableBytes() > 10) {
                    state = State.STREAMING;
                    logger.info("[{}] >>> !!! 推流通道已打通 !!!", streamName);

                    // Tell the client it can start sending media.
                    if (onReadyListener != null) {
                        onReadyListener.run();
                    }
                }
                break;
            }

            case STREAMING:
            default:
                // Inbound server messages while streaming are ignored.
                break;
        }
    }

    /**
     * Accumulates handshake bytes across reads and drives the client side of the simple handshake.
     * <p>
     * C2 (an echo of S1) is sent once S0+S1 have fully arrived. {@code connect} is sent only after
     * S2 has also arrived, so an S2 delivered in a later read is never mistaken for the connect
     * response. S2 contents are not verified.
     */
    private void handleHandshake(ChannelHandlerContext ctx, ByteBuf msg) {
        if (handshakeBuffer == null) {
            handshakeBuffer = ctx.alloc().buffer(1 + 2 * HANDSHAKE_PACKET_SIZE);
        }
        handshakeBuffer.writeBytes(msg);

        if (state == State.HANDSHAKE_C0C1) {
            if (handshakeBuffer.readableBytes() < 1 + HANDSHAKE_PACKET_SIZE) {
                return; // S0+S1 not complete yet
            }
            logger.info("[{}] 收到 S0+S1,握手第一阶段完成。", streamName);

            handshakeBuffer.skipBytes(1); // S0 (version byte)
            // C2 echoes S1; readBytes returns a new buffer that the pipeline releases once written.
            ctx.writeAndFlush(handshakeBuffer.readBytes(HANDSHAKE_PACKET_SIZE));
            logger.info("[{}] 已发送 C2,握手最后阶段...", streamName);
            state = State.HANDSHAKE_C2;
        }

        if (handshakeBuffer.readableBytes() < HANDSHAKE_PACKET_SIZE) {
            return; // S2 not complete yet
        }
        // Discards S2 (and any trailing bytes) together with the accumulation buffer.
        releaseHandshakeBuffer();
        logger.info("[{}] 收到 S2,握手完成。", streamName);

        sendSetChunkSize(ctx);
        sendConnect(ctx);

        state = State.CONNECT;
        logger.info("[{}] >>> 状态流转: HANDSHAKE -> CONNECT", streamName);
    }

    /**
     * Outbound path. Once streaming, each {@link ByteBuf} is re-framed as an RTMP message. Before
     * that, it is dropped. Other message types pass through unchanged.
     */
    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        if (!(msg instanceof ByteBuf)) {
            super.write(ctx, msg, promise);
            return;
        }

        ByteBuf flvTag = (ByteBuf) msg;
        if (state != State.STREAMING) {
            // Handshake/publish not finished: drop the tag, but complete the promise so callers
            // waiting on the write future are not left hanging.
            if (flvTag.refCnt() > 0) {
                flvTag.release();
            }
            failPromise(promise, new IllegalStateException("RTMP stream not ready, FLV tag dropped"));
            return;
        }

        try {
            wrapAndSendChunk(ctx, flvTag, promise);
        } catch (Exception e) {
            // wrapAndSendChunk always releases flvTag itself; only the promise is handled here.
            logger.error("[{}] Chunk发送异常", streamName, e);
            failPromise(promise, e);
        }
    }

    /**
     * Re-frames one FLV tag as a single RTMP message split into chunks, and writes it.
     * <p>
     * The message length is the number of bytes actually present after the 11-byte FLV tag header;
     * the tag's DataSize field is ignored. One FLV tag is expected per call (consistent with
     * jtt1078-video-server). Timestamps are sent relative to the first tag seen.
     * <p>
     * Always releases {@code flvTag}. The promise is completed by the write, or failed if the tag
     * is dropped.
     */
    private void wrapAndSendChunk(ChannelHandlerContext ctx, ByteBuf flvTag, ChannelPromise promise) {
        CompositeByteBuf outBuf = null;
        try {
            if (flvTag.readableBytes() < FLV_TAG_HEADER_SIZE) {
                if (logger.isDebugEnabled()) logger.debug("[{}] FLV Tag 长度不足11字节,丢弃", streamName);
                failPromise(promise, new IllegalArgumentException(
                        "FLV tag shorter than " + FLV_TAG_HEADER_SIZE + " bytes"));
                return;
            }

            // --- 1. FLV tag header ---
            int type = flvTag.readUnsignedByte();
            flvTag.skipBytes(3); // DataSize is not trusted; the real remaining size is used instead
            // Unsigned reads: readMedium() sign-extends, which turned every timestamp >= 0x800000 ms
            // into a negative value that was then clamped to 0.
            int timestampLower = flvTag.readUnsignedMedium();
            int timestampUpper = flvTag.readUnsignedByte();
            long timestamp = ((long) timestampUpper << 24) | timestampLower;
            flvTag.skipBytes(3); // StreamID

            // --- 2. Real body size ---
            int actualBodySize = flvTag.readableBytes();
            if (actualBodySize > MAX_MEDIUM) {
                logger.error("[{}] 检测到非法的包体大小: {}, 丢弃该包", streamName, actualBodySize);
                failPromise(promise, new IllegalArgumentException(
                        "RTMP message too large: " + actualBodySize));
                return;
            }

            // --- 3. Relative timestamp ---
            if (isFirstTag) {
                startTime = timestamp;
                isFirstTag = false;
            }
            long rtmpTimestamp = Math.max(0, timestamp - startTime);
            // Values >= 0xFFFFFF go in the 4-byte extended timestamp field.
            boolean extendedTimestamp = rtmpTimestamp >= MAX_MEDIUM;

            // --- 4. Chunk stream id: video on 6; audio and all other tag types on 4 ---
            int csid = (type == FLV_TAG_TYPE_VIDEO) ? 6 : 4;

            // Components: header + first slice, then (basic header + slice) per extra chunk. Sizing
            // the composite up front avoids the consolidation copy Netty makes once a composite
            // exceeds its component limit (16 for the no-arg compositeBuffer()).
            int chunkCount = Math.max(1, (actualBodySize + RTMP_CHUNK_SIZE - 1) / RTMP_CHUNK_SIZE);
            outBuf = ctx.alloc().compositeBuffer(2 * chunkCount);

            // Type 0 chunk header: basic header, timestamp, length, type id, message stream id (LE).
            ByteBuf header = ctx.alloc().buffer(extendedTimestamp ? 16 : 12);
            header.writeByte(csid & 0x3F);
            header.writeMedium(extendedTimestamp ? MAX_MEDIUM : (int) rtmpTimestamp);
            header.writeMedium(actualBodySize);
            header.writeByte(type);
            // NOTE(review): message stream id is hard-coded to 1, the same id publish is sent on;
            // the createStream result is not parsed.
            header.writeIntLE(1);
            if (extendedTimestamp) {
                header.writeInt((int) rtmpTimestamp);
            }
            outBuf.addComponent(true, header);

            // --- 5. Body, split into chunks of at most RTMP_CHUNK_SIZE bytes ---
            int remaining = actualBodySize;
            int firstChunkLen = Math.min(remaining, RTMP_CHUNK_SIZE);
            if (firstChunkLen > 0) {
                outBuf.addComponent(true, flvTag.readRetainedSlice(firstChunkLen));
                remaining -= firstChunkLen;
            }

            while (remaining > 0) {
                // Type 3 continuation header. The extended timestamp must be repeated whenever the
                // preceding Type 0 header carried one (RTMP spec 5.3.1.3).
                ByteBuf subHeader = ctx.alloc().buffer(extendedTimestamp ? 5 : 1);
                subHeader.writeByte(0xC0 | (csid & 0x3F));
                if (extendedTimestamp) {
                    subHeader.writeInt((int) rtmpTimestamp);
                }
                outBuf.addComponent(true, subHeader);

                int chunkLen = Math.min(remaining, RTMP_CHUNK_SIZE);
                outBuf.addComponent(true, flvTag.readRetainedSlice(chunkLen));
                remaining -= chunkLen;
            }

            CompositeByteBuf message = outBuf;
            outBuf = null; // ownership moves to the pipeline
            ctx.writeAndFlush(message, promise);
        } finally {
            if (outBuf != null) {
                outBuf.release();
            }
            flvTag.release();
        }
    }

    /**
     * Fails {@code promise} so that waiting callers are released.
     * <p>
     * A void promise is left untouched: it has no listeners. Netty also fires a failed void
     * promise's cause through the pipeline, where {@link #exceptionCaught} would close the channel.
     */
    private static void failPromise(ChannelPromise promise, Throwable cause) {
        if (!promise.isVoid()) {
            promise.tryFailure(cause);
        }
    }

    /** Releases the handshake accumulation buffer if the handler goes away mid-handshake. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        releaseHandshakeBuffer();
        super.handlerRemoved(ctx);
    }

    private void releaseHandshakeBuffer() {
        if (handshakeBuffer != null) {
            handshakeBuffer.release();
            handshakeBuffer = null;
        }
    }

    // =========================================================================
    // Command builders
    // =========================================================================

    /** Sends a Set Chunk Size protocol control message (type 1, chunk stream 2, message stream 0). */
    private void sendSetChunkSize(ChannelHandlerContext ctx) {
        logger.info("[{}] 发送 SetChunkSize 命令: {}", streamName, RTMP_CHUNK_SIZE);
        ByteBuf buf = Unpooled.buffer(16);
        buf.writeByte(0x02);
        buf.writeMedium(0);
        buf.writeMedium(4);
        buf.writeByte(0x01);
        buf.writeIntLE(0);
        buf.writeInt(RTMP_CHUNK_SIZE);
        ctx.writeAndFlush(buf);
    }

    /** Sends the AMF0 {@code connect} command (transaction id 1) with app, tcUrl, flashVer and swfUrl. */
    private void sendConnect(ChannelHandlerContext ctx) {
        logger.info("[{}] 发送 connect 命令. App: {}, TcUrl: {}", streamName, app, rtmpUrl);
        ByteBuf buf = Unpooled.buffer();
        Amf0Util.writeString(buf, "connect");
        Amf0Util.writeNumber(buf, 1.0);
        Map<String, Object> params = new HashMap<>();
        params.put("app", app);
        params.put("tcUrl", rtmpUrl);
        params.put("flashVer", "FMLE/3.0 (compatible; FMSc/1.0)");
        params.put("swfUrl", "");
        Amf0Util.writeObject(buf, params);
        writeCommandMessage(ctx, buf, 0);
    }

    /** Sends the AMF0 {@code createStream} command (transaction id 2) on message stream 0. */
    private void sendCreateStream(ChannelHandlerContext ctx) {
        logger.info("[{}] 发送 createStream 命令...", streamName);
        ByteBuf buf = Unpooled.buffer();
        Amf0Util.writeString(buf, "createStream");
        Amf0Util.writeNumber(buf, 2.0);
        Amf0Util.writeNull(buf);
        writeCommandMessage(ctx, buf, 0);
    }

    /** Sends the AMF0 {@code publish} command (transaction id 3, type "live") on message stream 1. */
    private void sendPublish(ChannelHandlerContext ctx) {
        // Append rtmpUrl's query string (e.g. ?sign=xxx) to the publish name so that auth
        // parameters reach the server.
        String publishName = streamName;
        int queryIndex = rtmpUrl.indexOf('?');
        if (queryIndex > 0) {
            String queryParams = rtmpUrl.substring(queryIndex);
            publishName = streamName + queryParams;
            logger.info("[{}] 发送 publish 命令. StreamName: {} (包含鉴权参数)", streamName, publishName);
        } else {
            logger.info("[{}] 发送 publish 命令. StreamName: {}", streamName, publishName);
        }

        ByteBuf buf = Unpooled.buffer();
        Amf0Util.writeString(buf, "publish");
        Amf0Util.writeNumber(buf, 3.0);
        Amf0Util.writeNull(buf);
        Amf0Util.writeString(buf, publishName);
        Amf0Util.writeString(buf, "live");
        writeCommandMessage(ctx, buf, 1);
    }

    /**
     * Sends an AMF0 command message (type 0x14) with a Type 0 header on chunk stream 3, timestamp 0.
     * The payload goes out as a single chunk, so it must not exceed the outgoing chunk size.
     */
    private void writeCommandMessage(ChannelHandlerContext ctx, ByteBuf payload, int streamId) {
        int len = payload.readableBytes();
        ByteBuf header = ctx.alloc().buffer(12);
        header.writeByte(0x03);
        header.writeMedium(0);
        header.writeMedium(len);
        header.writeByte(0x14);
        header.writeIntLE(streamId);
        ctx.write(header);
        ctx.writeAndFlush(payload);
    }

    /**
     * Renders up to the first 100 readable bytes for logging, without moving the reader index.
     * Each byte maps to one character (ISO-8859-1); non-printable characters become '.'.
     */
    private String safeReadAscii(ByteBuf buf) {
        int len = Math.min(buf.readableBytes(), 100);
        byte[] bytes = new byte[len];
        buf.getBytes(buf.readerIndex(), bytes);
        String raw = new String(bytes, StandardCharsets.ISO_8859_1);
        return raw.replaceAll("[^\\x20-\\x7E]", ".");
    }

    /** Logs any pipeline exception and closes the channel. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        logger.error("[{}] 通道异常", streamName, cause);
        ctx.close();
    }
}