Skip to content

Commit e41730c

Browse files
committed
Export the Config class to config.py; add common.SpecialDictWords(Enum).
1 parent bee8ccb commit e41730c

File tree

2 files changed

+80
-70
lines changed

2 files changed

+80
-70
lines changed

common.py

Lines changed: 9 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2,78 +2,17 @@
22
import json
33
import sys
44
from enum import Enum
5-
from math import ceil
6-
7-
8-
class Config:
9-
@staticmethod
10-
def get_default_config(args):
11-
config = Config()
12-
config.DL_FRAMEWORK = 'keras'
13-
config.NUM_EPOCHS = 20
14-
config.SAVE_EVERY_EPOCHS = 1
15-
config.BATCH_SIZE = 1024
16-
config.TEST_BATCH_SIZE = config.BATCH_SIZE
17-
config.READING_BATCH_SIZE = 1300 * 4
18-
config.NUM_BATCHING_THREADS = 2
19-
config.BATCH_QUEUE_SIZE = 300000
20-
config.MAX_CONTEXTS = 200
21-
config.WORDS_VOCAB_SIZE = 1301136
22-
config.TARGET_VOCAB_SIZE = 261245
23-
config.PATHS_VOCAB_SIZE = 911417
24-
config.EMBEDDINGS_SIZE = 128
25-
config.MAX_TO_KEEP = 10
26-
config.DROPOUT_KEEP_RATE = 0.75
27-
28-
config.READER_NUM_PARALLEL_BATCHES = 1 # cpu cores [for tf.contrib.data.map_and_batch()]
29-
config.SHUFFLE_BUFFER_SIZE = 10000
30-
config.CSV_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB
31-
32-
# Automatically filled, do not edit:
33-
config.TRAIN_PATH = args.data_path
34-
config.TEST_PATH = args.test_path
35-
config.SAVE_PATH = args.save_path
36-
config.LOAD_PATH = args.load_path
37-
config.RELEASE = args.release
38-
config.EXPORT_CODE_VECTORS = args.export_code_vectors
39-
return config
40-
41-
def __init__(self):
42-
self.DL_FRAMEWORK: str = '' # in {'keras', 'tensorflow'}
43-
self.NUM_EPOCHS: int = 0
44-
self.SAVE_EVERY_EPOCHS: int = 0
45-
self.BATCH_SIZE: int = 0
46-
self.TEST_BATCH_SIZE: int = 0
47-
self.READING_BATCH_SIZE: int = 0
48-
self.NUM_BATCHING_THREADS: int = 0
49-
self.BATCH_QUEUE_SIZE: int = 0
50-
self.MAX_CONTEXTS: int = 0
51-
self.WORDS_VOCAB_SIZE: int = 0
52-
self.TARGET_VOCAB_SIZE: int = 0
53-
self.PATHS_VOCAB_SIZE: int = 0
54-
self.EMBEDDINGS_SIZE: int = 0
55-
self.MAX_TO_KEEP: int = 0
56-
self.DROPOUT_KEEP_RATE: float = 0
57-
58-
self.READER_NUM_PARALLEL_BATCHES: int = 0
59-
self.SHUFFLE_BUFFER_SIZE: int = 0
60-
self.CSV_BUFFER_SIZE: int = 0
61-
62-
self.SAVE_PATH: str = ''
63-
self.LOAD_PATH: str = ''
64-
self.TRAIN_PATH: str = ''
65-
self.TEST_PATH: str = ''
66-
self.RELEASE: bool = False
67-
self.EXPORT_CODE_VECTORS: bool = False
68-
69-
@property
70-
def steps_per_epoch(self) -> int:
71-
return ceil(self.NUM_EXAMPLES / self.BATCH_SIZE)
5+
import tensorflow as tf
726

737

748
class common:
75-
noSuchWord = "NoSuchWord"
769

10+
class SpecialDictWords(Enum):
    """Reserved vocabulary entries that occupy the lowest dictionary indices."""

    NoSuchWord = 0

    @classmethod
    def index_to_start_dict_from(cls):
        """Return the first dictionary index that is free for real words,
        i.e. one past the highest reserved special-word value."""
        highest_reserved = max(member.value for member in cls)
        return highest_reserved + 1
7716

7817
@staticmethod
7918
def normalize_word(word):
@@ -209,7 +148,7 @@ def split_to_batches(data_lines, batch_size):
209148

210149
@staticmethod
211150
def legal_method_names_checker(name):
212-
return name != common.noSuchWord and re.match('^[a-zA-Z\|]+$', name)
151+
return name != common.SpecialDictWords.NoSuchWord.name and re.match('^[a-zA-Z\|]+$', name)
213152

214153
@staticmethod
215154
def filter_impossible_names(top_words):
@@ -227,7 +166,7 @@ def parse_results(result, unhash_dict, topk=5):
227166
original_name, top_suggestions, top_scores, attention_per_context = list(single_method)
228167
current_method_prediction_results = PredictionResults(original_name)
229168
for i, predicted in enumerate(top_suggestions):
230-
if predicted == common.noSuchWord:
169+
if predicted == common.SpecialDictWords.NoSuchWord.name:
231170
continue
232171
suggestion_subtokens = common.get_subtokens(predicted)
233172
current_method_prediction_results.append_prediction(suggestion_subtokens, top_scores[i].item())

config.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from math import ceil
2+
3+
4+
class Config:
    """Container for all model hyper-parameters, input-pipeline tuning
    knobs, data/model paths, and run flags.

    Instances are normally created via `Config.get_default_config(args)`;
    `__init__` only declares every field with a typed zero/empty default.
    """

    @staticmethod
    def get_default_config(args):
        """Build a Config populated with the default hyper-parameters,
        then copy the run-specific paths/flags from the parsed `args`.

        `args` must expose: data_path, test_path, save_path, load_path,
        release, export_code_vectors.
        """
        config = Config()

        # Training hyper-parameters.
        config.DL_FRAMEWORK = 'keras'
        config.NUM_EPOCHS = 20
        config.SAVE_EVERY_EPOCHS = 1
        config.BATCH_SIZE = 1024
        config.TEST_BATCH_SIZE = config.BATCH_SIZE  # evaluate with the training batch size
        config.READING_BATCH_SIZE = 1300 * 4
        config.NUM_BATCHING_THREADS = 2
        config.BATCH_QUEUE_SIZE = 300000
        config.MAX_CONTEXTS = 200
        config.WORDS_VOCAB_SIZE = 1301136
        config.TARGET_VOCAB_SIZE = 261245
        config.PATHS_VOCAB_SIZE = 911417
        config.EMBEDDINGS_SIZE = 128
        config.MAX_TO_KEEP = 10
        config.DROPOUT_KEEP_RATE = 0.75

        # Input-pipeline tuning.
        config.READER_NUM_PARALLEL_BATCHES = 1  # cpu cores [for tf.contrib.data.map_and_batch()]
        config.SHUFFLE_BUFFER_SIZE = 10000
        config.CSV_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB

        # Automatically filled, do not edit:
        config.TRAIN_PATH = args.data_path
        config.TEST_PATH = args.test_path
        config.SAVE_PATH = args.save_path
        config.LOAD_PATH = args.load_path
        config.RELEASE = args.release
        config.EXPORT_CODE_VECTORS = args.export_code_vectors
        return config

    def __init__(self):
        # Training hyper-parameters (overwritten by `get_default_config`).
        self.DL_FRAMEWORK: str = ''  # in {'keras', 'tensorflow'}
        self.NUM_EPOCHS: int = 0
        self.SAVE_EVERY_EPOCHS: int = 0
        self.BATCH_SIZE: int = 0
        self.TEST_BATCH_SIZE: int = 0
        self.READING_BATCH_SIZE: int = 0
        self.NUM_BATCHING_THREADS: int = 0
        self.BATCH_QUEUE_SIZE: int = 0
        self.MAX_CONTEXTS: int = 0
        self.WORDS_VOCAB_SIZE: int = 0
        self.TARGET_VOCAB_SIZE: int = 0
        self.PATHS_VOCAB_SIZE: int = 0
        self.EMBEDDINGS_SIZE: int = 0
        self.MAX_TO_KEEP: int = 0
        # Fix: default was the int `0`, inconsistent with the `float`
        # annotation and with every other field's typed default.
        self.DROPOUT_KEEP_RATE: float = 0.0

        # Input-pipeline tuning knobs.
        self.READER_NUM_PARALLEL_BATCHES: int = 0
        self.SHUFFLE_BUFFER_SIZE: int = 0
        self.CSV_BUFFER_SIZE: int = 0

        # Automatically filled by `args`.
        self.SAVE_PATH: str = ''
        self.LOAD_PATH: str = ''
        self.TRAIN_PATH: str = ''
        self.TEST_PATH: str = ''
        self.RELEASE: bool = False
        self.EXPORT_CODE_VECTORS: bool = False

        # Automatically filled by `ModelBase.__init__()`.
        self.NUM_EXAMPLES: int = 0

    @property
    def steps_per_epoch(self) -> int:
        """Number of training batches per epoch, rounding a final partial
        batch up.  NUM_EXAMPLES and a non-zero BATCH_SIZE must have been
        set before reading this property (raises ZeroDivisionError otherwise).
        """
        return ceil(self.NUM_EXAMPLES / self.BATCH_SIZE)

0 commit comments

Comments
 (0)