MENU

TRIE(4)

July 8, 2018 • Read: 3028 • 算法阅读设置

例3 题目链接:hihoCoder1289


这道题的大意是我们有一个网站,然后要配置规则,决定哪些IP能访问,哪些IP不能。这些规则大概长这个样子:

allow 1.2.3.4/30
deny 1.1.1.1
allow 127.0.0.1
allow 123.234.12.23/3
deny 0.0.0.0/0

;allow是允许访问,deny是不允许访问。后面这个1.2.3.4/30表示的是一个IP段,也就是允许或者不允许的IP范围。具体是哪个范围呢?我们先看IP1.2.3.4,我们知道一个IP由ABCD4段组成,中间用点分开。每一段的范围都是0~255,可以用8个二进制位表示,所以整个IP可以由32个二进制位表示。比如128.127.4.100就对应这个32位的二进制串:10000000011111110000010001100100

而如果一个IP,前20位与128.127.4.100的前20位相同,那么这个IP就在128.127.4.100/20的范围里。具体来说,就是10000000011111110000000000000000到 10000000011111110000111111111111

题目给出了N条规则,然后询问M个IP是允许访问还是拒绝访问。每个IP的判断规则是按顺序依次对比每一条规则,按第一条匹配上的规则处理。所谓匹配上,就是询问的IP在规则的范围里。如果N条规则都没匹配上,就按允许处理。这道题的背景还是挺真实的,如果你自己配过apache或者nginx什么的,你会发现配置文件确实是这么个格式的

首先这道题我们可以暴力做。就是对每一个询问,按顺序匹配每一条规则,一旦匹配上就按规则办事。由于最坏情况下,需要匹配所有N条规则,所以这样整个程序的时间复杂度是O(NM)的,大概只能通过40%的数据

要通过所有的数据我们就要用到Trie。对于一条规则,比如128.127.4.100/20,我们就把128.127.4.100的二进制01串的前20位插入到Trie中。而我们处理询问一个IP的时候,就在Trie中查找这个IP,看看沿途经过哪些终结点。显然经过的每个终结点都对应的是一条能匹配上的规则。但是我们要找的是这些规则中最早的一条。所以我们给终结点加一个序号,表示是第几条规则,然后在沿途的终结点中找序号最小的就可以了


例如在上图中,我们在Trie里找IP 000110101101…的时候,经过了10号和20号终结点,那么这两个终结点对应的规则哪一条靠前,哪一条就是这个IP匹配上的规则,最后按匹配上的规则输出结果就好了

#include <cstdio>
#include <cstring>
using namespace std;
const int MAX_NODE = 3300000 + 10;
const int CHARSET = 2;
int trie[MAX_NODE][CHARSET] = {0};
int color[MAX_NODE];
int n,m,ans,k = 1,b[32];
char action[10],s[100];
struct Rule {
    long long ip;
    int digit;
    bool allow;
} rules[100000];
Rule resolve(const char* address,bool allow = true) {
    Rule ret;
    ret.ip = 0;
    ret.allow = allow;
    const char* ptr = address;
    int cur_sect = 0;
    long long token = 0;
    while(*ptr != 0 && *ptr != '/') {
        if(*ptr == '.') {
            ret.ip |= token << ((3 - cur_sect) * 8);
            cur_sect++;
            token = 0; 
        }
        else 
            token = token * 10 + (*ptr - '0');
        ptr++;
    }
    ret.ip |= token << ((3 - cur_sect) * 8);
    int digits = 0;
    if(*ptr == 0) 
        digits = 32;
    else {
        ptr++;
        while(*ptr != 0) {
            digits = digits * 10 + (*ptr - '0');
            ptr++;
        }
    }
    ret.digit = digits;
    return ret;
}
void update(int p,int r) {
    if(color[p] == -1)
        color[p] = r;
}
void insert(long long x,int d,bool allow,int r) {
    if(d == 0) {
        update(0,r);
        return;
    }
    for(int i = 31;i >= 0;i--) {
        b[i] = x & 1;
        x >>= 1;
    }
    int p = 0;
    for(int i = 0;i < d;i++) {
        int c = b[i];
        if(!trie[p][c]) {
            trie[p][c] = k;
            k++;
        }
        p = trie[p][c];
    }
    update(p,r);
}
int search(long long x) {
    int r = color[0];
    for(int i = 31;i >= 0;i--) {
        b[i] = x & 1;
        x >>= 1; 
    }
    int p = 0;
    for(int i = 0;i < 32;i++) {
        int c = b[i];
        if(!trie[p][c])
            break;
        p = trie[p][c];
        if(color[p] != -1 && (r == -1 || r > color[p]))
            r = color[p];
    } 
    return r;
}
int main() {
    memset(color,-1,sizeof(color));
    scanf("%d%d",&n,&m);
    for(int i = 0;i < n;i++) {
        scanf("%s%s",action,s);
        rules[i] = resolve(s,action[0] == 'a');
        insert(rules[i].ip,rules[i].digit,rules[i].allow,i);
    }
    while(m--) {
        scanf("%s",s);
        Rule r = resolve(s);
        ans = search(r.ip);
        if(ans == -1 || rules[ans].allow)
            puts("YES");
        else
            puts("NO");
    }
    return 0;
}

首先我们看一下用到的变量,和main函数的整体逻辑。这道题我们要存储的一个IP对应的01串,所以注意第5行,字符集的大小被设置成了2。第10~14行的Rule结构数组是用来保存所有规则。我们用一个整数ip,也就是32位二进制代表的数值;digit指子网掩码的位数;allow表示是允许还是拒绝。特别需要注意一下color数组,如果一个节点i不是终结点,color[i]==-1;否则color[i]保存的是这个终结点对应的规则序号

Main函数的逻辑是90~94行处理每一个规则,首先把输入的字符串解析成Rule结构的格式,通过resolve()这个函数。然后再把解析出来的ip插入到trie中。第91~103行是在处理每一个询问,拿到一个字符串ip首先也是解析成一个整数ip。然后我们在trie中查找这个整数(代表的二进制串)。Search函数会返回路径上所有终结点对应的规则的最小序号

insert函数

首先update是用来更新节点p上的规则序号,参数r是新的规则序号。因为main函数中我们是按顺序插入的规则,而题目要求我们按最早匹配的规则行事,所以p也应该保留较早的规则。所以update是除非现在p上没有规则(color[p]=-1),才会令color[p]=r

Insert函数插入一个规则,规则的ip是x,掩码位数是d,允许与否是allow,序号是r。第51~54的特判,如果位数d=0,就直接更新在根节点0上的规则。第55~58是求出x的二进制串,b[0]是最高位,b[31]是最低位。第59~67是按位插入0/1,注意我们只插入前d位。最后在终结点p更新规则

search函数

X也是一个整数ip。首先用r记录根节点的规则序号。然后72~75行也是在计算x的01串。76~84行是在trie上进行搜索。注意第82行,这一行的逻辑是如果p是终结点,并且序号比当前的序号r更小,那么就更新当前序号。这样r就一直是保持沿途终结点中最小的序号。最后返回r给主函数

resolve函数

最后我们看一下resolve函数,这个函数就是字符串处理,把128.127.4.100/20这样的字符串分解+计算出整数ip和整数掩码位数

Last Modified: May 12, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment