hdu1251(字符串Trie)

Trie基础题

数组写法:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity of the node pool (number of rows in the tables below).
const int maxn = 1e6 + 10;
// trie[u][c] is the index of node u's child in slot c; 0 means "no child"
// (node 0 is the root and is never handed out as a child).
// num[u] counts how many inserted words pass through node u.
int trie[maxn][28], num[maxn];
//bool tail[maxn][28];
// Index of the next unused node; init() sets it to 1.
int sz;

// Reset the trie to a single empty root (node 0).
// Only the root's child row is cleared -- not the whole table; rows of
// other nodes are cleared by insert() at the moment they are allocated.
void init()
{
    const int width = sizeof(trie[0]) / sizeof(trie[0][0]);
    for(int c = 0; c < width; ++c)
        trie[0][c] = 0;

    sz = 1;     // node 0 is taken by the root, so the next free index is 1
}


// Map a letter to its child-slot index: 'a' -> 0, 'b' -> 1, and so on.
int idx(char ch)
{
    const int base = 'a';
    return ch - base;
}


// Insert the NUL-terminated word s into the trie, allocating nodes as
// needed, and increment num[] for every node on the word's path (the root
// itself is not counted).
void insert(char *s)
{
    int node = 0;
    for(char *p = s; *p != '\0'; ++p)
    {
        int slot = idx(*p);
        int next = trie[node][slot];
        if(next == 0)
        {
            // Take the next free node and clear its child row before linking it.
            next = sz++;
            memset(trie[next], 0, sizeof(trie[next]));
            trie[node][slot] = next;
        }
        node = next;
        ++num[node];
    }
}


int query(char *s)
{
    int u = 0, len = strlen(s);
    int sum = 0;

    for(int i = 0; i < len; i++)
    {
        int t = idx(s[i]);
        if(!trie[u][t])   return 0;
        u = trie[u][t];
    }
    sum = num[u];
    return sum;
}



// Reads the dictionary (one word per line, terminated by an empty line),
// then answers each following whitespace-separated query with the number
// of dictionary words that start with it.
int main()
{
    char word[15];
    init();

    // fgets() is bounded by the buffer size, unlike gets(), which can write
    // past the end of word[] and has been removed from the language.
    while(fgets(word, sizeof(word), stdin) != NULL)
    {
        size_t n = strcspn(word, "\r\n");
        if(word[n] == '\0')
        {
            // No line terminator was stored: the line is longer than the
            // buffer (or input ended). Discard the rest of the line so it
            // is not mistaken for a separate word or for the blank line.
            int c;
            while((c = getchar()) != '\n' && c != EOF)
                ;
        }
        word[n] = '\0';         // drop the line terminator kept by fgets()

        if(word[0] == '\0')     // empty line ends the dictionary section
            break;
        insert(word);
    }

    // %14s keeps scanf from writing past word[15].
    while(scanf("%14s", word) == 1)
    {
        int ans = query(word);
        printf("%d\n", ans);
    }
    return 0;
}

链表版本:

#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std;
// NOTE(review): maxn is not referenced by the linked-list code visible here.
const int maxn = 3e3 + 10;
// Number of child pointers stored in every trie node.
const int sz = 28;

// One node of the pointer-based trie.
struct trie
{
    int num;                    // number of inserted words whose path passes through this node
    bool tail;                  // set by insert() on the node where a word ends; not read by the code shown here
    struct trie *next[sz];      // child pointers indexed by idx(letter); NULL when the child is absent
};

// Root node of the trie; main() assigns it from init().
trie *root;

// Allocate a fresh trie node with a zero counter, no end-of-word mark and
// no children, and return it. The caller owns the returned node.
trie *init()
{
    trie *node = (trie*)malloc(sizeof(*node));

    node->tail = 0;
    node->num = 0;

    int i = sz;
    while(i-- > 0)
        node->next[i] = NULL;

    return node;
}


// Map a letter to its child-slot index: 'a' -> 0, 'b' -> 1, and so on.
int idx(char ch)
{
    const int base = 'a';
    return ch - base;
}


void insert(char *s)
{
    trie *p = root;
    int len = strlen(s);
    for(int i = 0; i < len; i++)
    {
        int t = idx(s[i]);
        if(p->next[t] == NULL)
        {
            p->next[t] = init();
        }
        p = p->next[t];
        p->num++;
    }
    p->tail = 1;
}


int query(char *s)
{
    trie *p = root;
    int len = strlen(s);
    for(int i = 0; i < len; i++)
    {
        int t = idx(s[i]);
        if(p->next[t] == NULL)  return 0;
        p = p->next[t];
    }
    return p->num;
}

// Free the subtree rooted at p: every child subtree first, then p itself.
// Passing NULL is allowed and does nothing.
void del(trie *p)
{
    if(p == NULL)
        return;

    for(int i = 0; i < sz; i++)
    {
        if(p->next[i])
            del(p->next[i]);
    }
    // Release the node itself. The original called del(p) here, which
    // recursed on the same node forever and never released any memory.
    free(p);
}



// Reads the dictionary (one word per line, terminated by an empty line),
// then answers each following whitespace-separated query with the number
// of dictionary words that start with it, and finally frees the trie.
int main()
{
    char word[15];
    root = init();

    // fgets() is bounded by the buffer size, unlike gets(), which can write
    // past the end of word[] and has been removed from the language.
    while(fgets(word, sizeof(word), stdin) != NULL)
    {
        size_t n = strcspn(word, "\r\n");
        if(word[n] == '\0')
        {
            // No line terminator was stored: the line is longer than the
            // buffer (or input ended). Discard the rest of the line so it
            // is not mistaken for a separate word or for the blank line.
            int c;
            while((c = getchar()) != '\n' && c != EOF)
                ;
        }
        word[n] = '\0';         // drop the line terminator kept by fgets()

        if(word[0] == '\0')     // empty line ends the dictionary section
            break;
        insert(word);
    }

    // %14s keeps scanf from writing past word[15].
    while(scanf("%14s", word) == 1)
    {
        int ans = query(word);
        printf("%d\n", ans);
    }
    del(root);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_38577732/article/details/89497504