Skip to content

Commit e1e0673

Browse files
committed
trie in python
1 parent 17bbd62 commit e1e0673

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

python/35_trie/trie_.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#!/usr/bin/python
2+
# -*- coding: UTF-8 -*-
3+
4+
from queue import Queue
5+
import pygraphviz as pgv
6+
7+
OUTPUT_PATH = 'E:/'
8+
9+
10+
class Node:
11+
def __init__(self, c):
12+
self.data = c
13+
self.is_ending_char = False
14+
# 使用有序数组,降低空间消耗,支持更多字符
15+
self.children = []
16+
17+
def insert_child(self, c):
18+
"""
19+
插入一个子节点
20+
:param c:
21+
:return:
22+
"""
23+
v = ord(c)
24+
idx = self._find_insert_idx(v)
25+
length = len(self.children)
26+
27+
node = Node(c)
28+
if idx == length:
29+
self.children.append(node)
30+
else:
31+
self.children.append(None)
32+
for i in range(length, idx, -1):
33+
self.children[i] = self.children[i-1]
34+
self.children[idx] = node
35+
36+
def get_child(self, c):
37+
"""
38+
搜索子节点并返回
39+
:param c:
40+
:return:
41+
"""
42+
start = 0
43+
end = len(self.children) - 1
44+
v = ord(c)
45+
46+
while start <= end:
47+
mid = (start + end)//2
48+
if v == ord(self.children[mid].data):
49+
return self.children[mid]
50+
elif v < ord(self.children[mid].data):
51+
end = mid - 1
52+
else:
53+
start = mid + 1
54+
# 找不到返回None
55+
return None
56+
57+
def _find_insert_idx(self, v):
58+
"""
59+
二分查找,找到有序数组的插入位置
60+
:param v:
61+
:return:
62+
"""
63+
start = 0
64+
end = len(self.children) - 1
65+
66+
while start <= end:
67+
mid = (start + end)//2
68+
if v < ord(self.children[mid].data):
69+
end = mid - 1
70+
else:
71+
if mid + 1 == len(self.children) or v < ord(self.children[mid+1].data):
72+
return mid + 1
73+
else:
74+
start = mid + 1
75+
# v < self.children[0]
76+
return 0
77+
78+
def __repr__(self):
79+
return 'node value: {}'.format(self.data) + '\n' \
80+
+ 'children:{}'.format([n.data for n in self.children])
81+
82+
83+
class Trie:
84+
def __init__(self):
85+
self.root = Node(None)
86+
87+
def gen_tree(self, string_list):
88+
"""
89+
创建trie树
90+
91+
1. 遍历每个字符串的字符,从根节点开始,如果没有对应子节点,则创建
92+
2. 每一个串的末尾节点标注为红色(is_ending_char)
93+
:param string_list:
94+
:return:
95+
"""
96+
for string in string_list:
97+
n = self.root
98+
for c in string:
99+
if n.get_child(c) is None:
100+
n.insert_child(c)
101+
n = n.get_child(c)
102+
n.is_ending_char = True
103+
104+
def search(self, pattern):
105+
"""
106+
搜索
107+
108+
1. 遍历模式串的字符,从根节点开始搜索,如果途中子节点不存在,返回False
109+
2. 遍历完模式串,则说明模式串存在,再检查树中最后一个节点是否为红色,是
110+
则返回True,否则False
111+
:param pattern:
112+
:return:
113+
"""
114+
assert type(pattern) is str and len(pattern) > 0
115+
116+
n = self.root
117+
for c in pattern:
118+
if n.get_child(c) is None:
119+
return False
120+
n = n.get_child(c)
121+
122+
return True if n.is_ending_char is True else False
123+
124+
def draw_img(self, img_name='Trie.png'):
125+
"""
126+
画出trie树
127+
:param img_name:
128+
:return:
129+
"""
130+
if self.root is None:
131+
return
132+
133+
tree = pgv.AGraph('graph foo {}', strict=False, directed=False)
134+
135+
# root
136+
nid = 0
137+
color = 'black'
138+
tree.add_node(nid, color=color, label='None')
139+
140+
q = Queue()
141+
q.put((self.root, nid))
142+
while not q.empty():
143+
n, pid = q.get()
144+
for c in n.children:
145+
nid += 1
146+
q.put((c, nid))
147+
color = 'red' if c.is_ending_char is True else 'black'
148+
tree.add_node(nid, color=color, label=c.data)
149+
tree.add_edge(pid, nid)
150+
151+
tree.graph_attr['epsilon'] = '0.01'
152+
tree.layout('dot')
153+
tree.draw(OUTPUT_PATH + img_name)
154+
return True
155+
156+
157+
if __name__ == '__main__':
158+
string_list = ['abc', 'abd', 'abcc', 'accd', 'acml', 'P@trick', 'data', 'structure', 'algorithm']
159+
160+
print('--- gen trie ---')
161+
print(string_list)
162+
trie = Trie()
163+
trie.gen_tree(string_list)
164+
# trie.draw_img()
165+
166+
print('\n')
167+
print('--- search result ---')
168+
search_string = ['a', 'ab', 'abc', 'abcc', 'abe', 'P@trick', 'P@tric', 'Patrick']
169+
for ss in search_string:
170+
print('[pattern]: {}'.format(ss), '[result]: {}'.format(trie.search(ss)))

0 commit comments

Comments
 (0)