Skip to content

Commit 9eed746

Browse files
committed
Pruning large domain variables
1 parent 10697da commit 9eed746

File tree

5 files changed

+31
-70
lines changed

5 files changed

+31
-70
lines changed

net/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def save(self):
4343
for key, value in self.data.iteritems()
4444
}
4545
print json.dumps(data, indent=4)
46-
with open(STORE, 'w') as file:
47-
json.dump(data, file, indent=4)
46+
with open(STORE, 'w') as bitch:
47+
json.dump(data, bitch, indent=4)
4848

4949

5050
class MyTCPHandler(BaseRequestHandler):

src/data/__init__.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from functools import partial
33

44

5+
DOMAIN_CAP = 5
6+
7+
58
class DataSet(object):
69
"""
710
A dataset (D) is a set of records (D_1 ... D_N), each of which is an
@@ -53,23 +56,32 @@ def __init__(self, name='scale'):
5356
item.append(variable.process(value))
5457
self.data.append(item)
5558

56-
# Prune null variables
59+
# Prune null variables (limit domains + delete if necessary)
5760
self.variables = [v for v in self.variables if v]
5861
[v.finish(self.data, idx) for idx, v in enumerate(self.variables)]
62+
kill_idxs = list(reversed(sorted([
63+
idx for idx, variable in enumerate(self.variables)
64+
if len(variable.domain) > DOMAIN_CAP
65+
])))
66+
print 'Killed variables', kill_idxs
67+
for idx in kill_idxs:
68+
del self.variables[idx]
5969
self.variable_map = {v.name: v for v in self.variables}
6070

6171
# Order variables and data by domain size
6272
sorts = sorted(self.variables, key=lambda x: len(x.domain))[::-1]
6373
new_data = []
6474
for item in self.data:
75+
for idx in kill_idxs:
76+
del item[idx]
6577
new_item = [None] * len(item)
6678
for idx, variable in enumerate(self.variables):
6779
new_item[sorts.index(variable)] = item[idx]
6880
new_data.append(new_item)
6981
self.data = new_data
7082
self.variables = sorts
71-
# for variable in self.variables:
72-
# print '{}: {}'.format(variable.name, repr(variable.domain))
83+
for variable in self.variables:
84+
print '{}: {}'.format(variable.name, repr(variable.domain))
7385

7486

7587
def is_num(test):
@@ -106,11 +118,9 @@ def process(self, value):
106118
return item
107119

108120
def finish(self, data, idx):
109-
space = 5
110-
if len(self.domain) > space and all([is_num(x) for x in self.domain]):
121+
if len(self.domain) > DOMAIN_CAP and all([is_num(x) for x in self.domain]):
111122
ints = [float(x) for x in self.domain]
112-
# self.var_type = binning(space, min(ints), max(ints))
113-
self.var_type = partial(bitchen_spaces_bro, space, min(ints), max(ints))
123+
self.var_type = partial(bitchen_spaces_bro, DOMAIN_CAP, min(ints), max(ints))
114124
for item in data:
115125
item[idx] = self.var_type(item[idx])
116126
self.domain = set([self.var_type(x) for x in self.domain])

src/data/adult/__init__.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/data/flag/__init__.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/scores.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@ def __call__(self, name):
6464
print 'Expanding Nodes'
6565
self.cap = log(2 * self.N / log(self.N))
6666
self.progress = Bar()
67-
self.progress.set_base(self.count_node_expansions(), True)
67+
expansion_count = self.count_node_expansions()
68+
self.progress.set_base(expansion_count, True)
6869
self.expand_ad_node(-1, set(), set(range(self.N)))
6970

7071
print 'Prune Variables'
71-
self.progress.set_base(
72-
self.count_prune(set()) * len(self.variables), True
73-
)
72+
prune_count = self.count_prune(set()) * len(self.variables)
73+
self.progress.set_base(prune_count, True)
7474
for X in self.variables:
7575
self.prune(X, set(), self.score.get(X, set()))
7676

77-
# Fuck if I know
77+
# clear negative scores
7878
for key, value in self.score.cache.iteritems():
7979
if value < 0:
8080
self.score.cache[key] = 0
@@ -84,10 +84,12 @@ def __call__(self, name):
8484
with open(file_path, 'wb') as f:
8585
pickle.dump(self.score.cache, f)
8686
log_path = path.abspath(path.join(
87-
path.dirname(__file__), 'data', name, 'time.txt'
87+
path.dirname(__file__), 'data', name, 'info.txt'
8888
))
8989
with open(log_path, 'w') as f:
90-
f.write('{}s\n'.format(stop - start))
90+
f.write('Time: {}s\nExpansions: {}\nPrunes: {}'.format(
91+
stop - start, expansion_count, prune_count
92+
))
9193

9294
return self.score
9395

@@ -99,7 +101,7 @@ def count_node_expansions(self, i=-1, depth=0):
99101
self.variables.index(variable), depth + 1
100102
)
101103
size += count * len(variable.domain)
102-
return float(size)
104+
return size
103105

104106
def expand_ad_node(self, i, U, D_u):
105107
"""
@@ -145,12 +147,12 @@ def update_scores(self, U, D_size):
145147
def count_prune(self, U):
146148
U = frozenset(U)
147149
if U in self._cache:
148-
return 1.0
150+
return 1
149151
size = 1
150152
for X in self.vset.difference(U):
151153
size += self.count_prune(U.union({X}))
152154
self._cache.add(U)
153-
return float(size)
155+
return size
154156

155157
_cache2 = set()
156158

0 commit comments

Comments
 (0)