教你怎么模拟一个数据库服务端拦截客户端的语句并且对发送的内容进行修改

教你怎么模拟一个数据库服务端拦截客户端的语句并且对发送的内容进行修改,第1张

教你怎么模拟一个数据库服务端拦截客户端的语句并且对发送的内容进行修改 如何拦截客户端访问数据库的请求并且对发送的请求进行修改处理实现敏感数据模糊化 背景

一年结束了,又是新的一年开始了,针对前段时间研究的内容做一个简单的总结,上次写文章还是在上次,写的内容是基于netty做的一个tcp端口动态代理的的工程,当时做这个工具的目的是为了解决两台服务器之间不能直接通信所以在一台两个服务都能访问的代理服务器上去转发流量的问题,但是这个工程就只做了流量的的转发,并没有对流量进行解析,那么这篇文章就基于上个工程做一个简单的扩展,对转发的流量进行解析,对于我们不希望转发的流量可以进行拦截或修改成我们希望转发的流量。

原理

这里选择的解析流量的内容是基于数据库sql语句的解析,然后对解析的sql进行修改,新的sql语句组装成新的报文发送给数据库服务器,简单来说就是客户端发送的sql语句是(select name,phone from user ) 经过我们代理服务器进行加工后可以让服务器收到新的sql语句为(select name,phone from user limit 10)(这个 *** 作就可以有效防止有人恶意攻击数据库一直查询大量数据导致数据库繁忙),对客户端和服务端都是无感知的从而达到防止攻击的目的,当然这里替换的功能很多,可以做很多事情,其目的并不是只是加一个限制行数而已。

说明

这只是一个验证猜想的实验产品,重心都放在解析流量上了,所以代码就是写流水先实现功能便于排查问题的,优化空间还很大后续再处理。

实现

接下来将对数据库的报文进行解析:

mysql数据包结构组成

代码解析

public class MySqlParser extends DefaultSqlParser {
    //脱敏算法,脱敏后面百分之五十
    String rule = "concat(SUBSTr(#field#,1,CHAR_LENGTH(#field#)/2),substr('*************',CHAR_LENGTH(#field#)/2,CHAR_LENGTH(#field#)/2)) as #field#";
    public static Charset defaultCharset = Charset.forName("gbk");

    @Override
    public void dealChannel(ProxyConfig config, Channel channel, Object msg) {

        ByteBuf readBuffer = (ByteBuf) msg;
        //如果是服务端发送的消息远程地址为空
        InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
        String hostString = remoteAddress.getHostString();
        int port = remoteAddress.getPort();
        //只有发送给数据库的数据才需要进行处理
        int readableBytes = readBuffer.readableBytes();
        if (hostString.equals(config.getRemoteAddr()) && Objects.equals(port, config.getRemotePort()) && readableBytes > 5) {
            //当前的数据流中的数据长度
            //前五位表示请求头
            byte[] preDatas = new byte[5];
            readBuffer.getBytes(0, preDatas);

            if (preDatas[4] == MySQLPacket.COM_QUERY) {
                String localPid = channel.localAddress().toString();
                Charset charset = getCharset(localPid);
                byte[] oldDatas = new byte[readableBytes - 5];
                readBuffer.getBytes(5, oldDatas);
                String sql = new String(oldDatas, Optional.ofNullable(charset).orElse(defaultCharset)).trim();
                //设置此次回话的数据格式
                if (sql.toUpperCase().startsWith("SET NAMES")) {
                    try {
                        putCharset(localPid, Charset.forName(sql.split(" ")[2]));
                    } catch (Exception e) {
                    }
                }
                //如果是查询语句咋替换sql
                if (sql.toUpperCase().trim().startsWith("SELECt")) {
                    dealQuerySql(readBuffer, preDatas, sql);
                }
            }
        }
        channel.writeAndFlush(readBuffer);
    }

    
    public void dealQuerySql(ByteBuf readBuffer, byte[] preDatas, String sql) {
        //避开替换内置的schema表语句
        if (sql.toLowerCase().startsWith("select") && (!sql.toLowerCase().contains("information_schema"))) {
            int select = sql.toLowerCase().indexOf("select");
            int form = sql.indexOf("from");
            String substring = sql.substring(select, form).replace("select", "");
            String[] split = substring.split(",");
            List list = new ArrayList<>();
            for (String column : split) {
                list.add(rule.replace("#field#", column));
            }
            String join = StringUtils.join(list, ",");
            String newSql = "select" + " " + join + " " + sql.substring(form);
            byte[] newSqlBytes = newSql.getBytes();

            System.out.println("newsql=" + newSql);
            PrintUtil.print(newSqlBytes);
            dealHeaderBytes(preDatas, newSqlBytes.length);
            PrintUtil.print(preDatas);

            readBuffer.writerIndex(0);
            readBuffer.writeBytes(preDatas);
            readBuffer.writeBytes(newSqlBytes);
        }
    }

    
    void dealHeaderBytes(byte[] preDatas, int length) {
        //因为最后还有以一位是0结束位必须算进去
        length = length + 1;
        preDatas[0] = (byte) (length & 0xff);
        preDatas[1] = (byte) (length >>> 8);
        preDatas[2] = (byte) (length >>> 16);
    }
效果图

postgrepsql数据包结构组成 来源于dbeaver,jdbc,idea工具连接的数据包结构

来源于navicate,psql工具连接的数据包结构

代码解析

public class PostGrepSqlParser extends DefaultSqlParser {
    String charset = "utf8";
    String rule = "concat(SUBSTr(#field#,1,CHAR_LENGTH(#field#)/2),substr('*************',CHAR_LENGTH(#field#)/2,CHAR_LENGTH(#field#)/2)) as #field#";

    public void dealChannel(ProxyConfig config, Channel channel, Object msg) {
        ByteBuf readBuffer = (ByteBuf) msg;
        int oldByteLength = readBuffer.readableBytes();
        InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
        String hostString = remoteAddress.getHostString();
        int remotePort = remoteAddress.getPort();
        if (Objects.equals(hostString, config.getRemoteAddr()) && Objects.equals(config.getRemotePort(), remotePort) && oldByteLength > 8) {
            //取第一位,如果是80表示从jdbc和idea来的请求,数据复杂一点,如果是81表示从navicat和psql的客户端来的请求,结构稍微简单点
            int startByte = readBuffer.getByte(0);
            switch (startByte) {
                case 80:
                    dealComplex(channel, readBuffer);
                    break;
                case 81:
                    dealSimple(channel, readBuffer);
                    break;
                default:
                    readBuffer.retain();
                    channel.writeAndFlush(readBuffer);
                    break;
            }
        } else {
            channel.writeAndFlush(readBuffer);

        }
    }

    

    public static int getByteLength(byte[] data) {
        int result = 0;
        for (int i = 0; i < data.length; i++) {
            result += (data[i] & 0xff) << ((3 - i) * 8);
        }
        return result;
    }

    
    public static void setHeaderBytes(int length, byte[] data) {
        data[0] = (byte) (length >>> 24);
        data[1] = (byte) (length >>> 16);
        data[2] = (byte) (length >>> 8);
        data[3] = (byte) length;
    }

    
    void dealSimple(Channel channel, ByteBuf readBuffer) {
        int oldByteLength = readBuffer.readableBytes();
        byte headByte = readBuffer.getByte(0);
        byte[] headerBytes = new byte[4];
        readBuffer.getBytes(1, headerBytes);
        //获取长度
        int byteLength = getByteLength(headerBytes);
        //读取数据
        byte[] oldSqlBytes = new byte[byteLength - 5];
        readBuffer.getBytes(5, oldSqlBytes);
        String oldSql = new String(oldSqlBytes);
        readBuffer.retain();
        if (oldSql.toLowerCase().startsWith("select") && (!oldSql.toLowerCase().contains("information_schema"))) {
            String newSql = replaceSql(oldSql);
            byte[] newSqlBytes = newSql.getBytes();
            setHeaderBytes(newSqlBytes.length + 5, headerBytes);
            readBuffer.writerIndex(0);
            readBuffer.writeByte(headByte);
            readBuffer.writeBytes(headerBytes);
            //这种数据包格式的服务端一次只能接收64个字节的包,比较恶心需要分多次发送
            //这里有很大优化空间,重心现在放在解析数据包上暂不处理,后续再优化
            for (int i = 0; i < newSqlBytes.length; i++) {
                readBuffer.writeByte(newSqlBytes[i]);
                int index = readBuffer.writerIndex();
                if (index == 64) {
                    channel.writeAndFlush(readBuffer);
                    readBuffer = Unpooled.buffer(64);
                }

            }
            //注意这里的结束位不能省略
            readBuffer.writeByte(0);
            channel.writeAndFlush(readBuffer);

        } else {
            readBuffer.retain();
            channel.writeAndFlush(readBuffer);
        }

    }

    
    public String replaceSql(String sql) {
        //这里有可能会出现select version等情况的sql,后续再处理,也可能就不会走到这里来先忽略
        try {
            int select = sql.toLowerCase().indexOf("select");
            int form = sql.indexOf("from");
            String substring = sql.substring(select, form);
            String[] split = substring.split(",");
            List list = new ArrayList<>();
            for (String s : split) {
                String select1 = s.replace("select", "");
                list.add(rule.replace("#field#", select1));
            }
            String join = StringUtils.join(list, ",");
            sql = "select" + " " + join + " " + sql.substring(form);
        } catch (Exception e) {
            System.out.println("错误sql:" + sql);
        }
        return sql;

    }

    
    void dealComplex(Channel channel, ByteBuf readBuffer) {
        int oldByteLength = readBuffer.readableBytes();

        byte headByte = readBuffer.getByte(0);
        byte[] headerBytes = new byte[4];
        readBuffer.getBytes(1, headerBytes);
        //获取长度
        int byteLength = getByteLength(headerBytes);
        //读取数据
        byte[] oldSqlBytes = new byte[byteLength - 8];
        readBuffer.getBytes(6, oldSqlBytes);
        String oldSql = new String(oldSqlBytes);
        byte[] endBytes = new byte[oldByteLength - byteLength + 8 - 6];
        readBuffer.getBytes(byteLength - 8 + 6, endBytes);
        readBuffer.retain();
        if (oldSql.toLowerCase().contains("select") && (!oldSql.toLowerCase().contains("information_schema"))) {
            String newSql = replaceSql(oldSql);
            byte[] newSqlBytes = newSql.getBytes();
            setHeaderBytes(newSqlBytes.length + 8, headerBytes);
            readBuffer.writerIndex(0);
            readBuffer.writeByte(headByte);
            readBuffer.writeBytes(headerBytes);
            readBuffer.writeByte(0);
            readBuffer.writeBytes(newSqlBytes);
            readBuffer.writeBytes(endBytes);
        }
        channel.writeAndFlush(readBuffer);
    }
}

效果图

说明

由于oracle、sqlserver和gbase。。。等其他数据库是非开源数据库,开源易被举报,这里不便展示他们的报文格式,只能展示两种开源数据库的报文格式,有兴趣的同学可以自行研究。

温馨提示:

sql不支持*的写法,只能写具体字段,因为这只是一个研究原理用的实验产品,很多细节还没处理,bug也还有很多,只给大家提供一个研究思路。

欢迎分享,转载请注明来源:内存溢出

原文地址:https://www.54852.com/zaji/5697304.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-12-17
下一篇2022-12-17

发表评论

登录后才能评论

评论列表(0条)

    保存