Skip to content

Commit 5d54d43

Browse files
authored
Merge pull request #29 from brunobastosg/different-alphas
One alpha parameter per rule, which allows disabling certain rules. Thanks @brunobastosg for the PR, LGTM.
2 parents 14e43c6 + 2fe06d7 commit 5d54d43

File tree

3 files changed

+51
-26
lines changed

3 files changed

+51
-26
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ Now place this input file into the `data` folder. Run
5656
python code/augment.py --input=<insert input filename>
5757
```
5858

59-
The default output filename will append `eda_` to the front of the input filename, but you can specify your own with `--output`. You can also specify the number of generated augmented sentences per original sentence using `--num_aug` (default is 9). Furthermore, you can specify the alpha parameter, which approximately means the percent of words in the sentence that will be changed (default is `0.1` or `10%`). So for example, if your input file is `sst2_train.txt` and you want to output to `sst2_augmented.txt` with `16` augmented sentences per original sentence and `alpha=0.05`, you would do:
59+
The default output filename prepends `eda_` to the input filename, but you can specify your own with `--output`. You can also specify the number of augmented sentences generated per original sentence using `--num_aug` (default is 9). Furthermore, you can specify a separate alpha parameter for each rule (`--alpha_sr`, `--alpha_ri`, `--alpha_rs`, `--alpha_rd`); each one is approximately the percentage of words in the sentence that will be changed by that rule (default is `0.1`, or `10%`, for each). For example, if your input file is `sst2_train.txt` and you want to output to `sst2_augmented.txt` with `16` augmented sentences per original sentence, replacing 5% of words with synonyms (`alpha_sr=0.05`), deleting 10% of words (`alpha_rd=0.1`, or leave it as the default), and applying neither random insertion (`alpha_ri=0.0`) nor random swap (`alpha_rs=0.0`), you would do:
6060

6161
```bash
62-
python code/augment.py --input=sst2_train.txt --output=sst2_augmented.txt --num_aug=16 --alpha=0.05
62+
python code/augment.py --input=sst2_train.txt --output=sst2_augmented.txt --num_aug=16 --alpha_sr=0.05 --alpha_rd=0.1 --alpha_ri=0.0 --alpha_rs=0.0
6363
```
6464

65-
Note that at least one augmentation operation is applied per augmented sentence regardless of alpha. So if you do `alpha=0.001` and your sentence only has four words, one augmentation operation will still be performed. Best of luck!
65+
Note that at least one augmentation operation is applied per augmented sentence regardless of how small alpha is, as long as it is greater than zero. So if you set `alpha_sr=0.001` and your sentence only has four words, one augmentation operation will still be performed. If a particular alpha is exactly zero, the corresponding rule is skipped entirely. Best of luck!
6666

6767
# Citation
6868
If you use EDA in your paper, please cite us:

code/augment.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
ap.add_argument("--input", required=True, type=str, help="input file of unaugmented data")
1010
ap.add_argument("--output", required=False, type=str, help="output file of unaugmented data")
1111
ap.add_argument("--num_aug", required=False, type=int, help="number of augmented sentences per original sentence")
12-
ap.add_argument("--alpha", required=False, type=float, help="percent of words in each sentence to be changed")
12+
ap.add_argument("--alpha_sr", required=False, type=float, help="percent of words in each sentence to be replaced by synonyms")
13+
ap.add_argument("--alpha_ri", required=False, type=float, help="percent of words in each sentence to be inserted")
14+
ap.add_argument("--alpha_rs", required=False, type=float, help="percent of words in each sentence to be swapped")
15+
ap.add_argument("--alpha_rd", required=False, type=float, help="percent of words in each sentence to be deleted")
1316
args = ap.parse_args()
1417

1518
#the output file
@@ -25,13 +28,31 @@
2528
if args.num_aug:
2629
num_aug = args.num_aug
2730

28-
#how much to change each sentence
29-
alpha = 0.1#default
30-
if args.alpha:
31-
alpha = args.alpha
31+
#how much to replace each word by synonyms
32+
alpha_sr = 0.1#default
33+
if args.alpha_sr is not None:
34+
alpha_sr = args.alpha_sr
35+
36+
#how much to insert new words that are synonyms
37+
alpha_ri = 0.1#default
38+
if args.alpha_ri is not None:
39+
alpha_ri = args.alpha_ri
40+
41+
#how much to swap words
42+
alpha_rs = 0.1#default
43+
if args.alpha_rs is not None:
44+
alpha_rs = args.alpha_rs
45+
46+
#how much to delete words
47+
alpha_rd = 0.1#default
48+
if args.alpha_rd is not None:
49+
alpha_rd = args.alpha_rd
50+
51+
if alpha_sr == alpha_ri == alpha_rs == alpha_rd == 0:
52+
ap.error('At least one alpha should be greater than zero')
3253

3354
#generate more data with standard augmentation
34-
def gen_eda(train_orig, output_file, alpha, num_aug=9):
55+
def gen_eda(train_orig, output_file, alpha_sr, alpha_ri, alpha_rs, alpha_rd, num_aug=9):
3556

3657
writer = open(output_file, 'w')
3758
lines = open(train_orig, 'r').readlines()
@@ -40,7 +61,7 @@ def gen_eda(train_orig, output_file, alpha, num_aug=9):
4061
parts = line[:-1].split('\t')
4162
label = parts[0]
4263
sentence = parts[1]
43-
aug_sentences = eda(sentence, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug)
64+
aug_sentences = eda(sentence, alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=alpha_rd, num_aug=num_aug)
4465
for aug_sentence in aug_sentences:
4566
writer.write(label + "\t" + aug_sentence + '\n')
4667

@@ -51,4 +72,4 @@ def gen_eda(train_orig, output_file, alpha, num_aug=9):
5172
if __name__ == "__main__":
5273

5374
#generate augmented sentences and output into a new file
54-
gen_eda(args.input, output, alpha=alpha, num_aug=num_aug)
75+
gen_eda(args.input, output, alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, alpha_rd=alpha_rd, num_aug=num_aug)

code/eda.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -179,29 +179,33 @@ def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9)
179179

180180
augmented_sentences = []
181181
num_new_per_technique = int(num_aug/4)+1
182-
n_sr = max(1, int(alpha_sr*num_words))
183-
n_ri = max(1, int(alpha_ri*num_words))
184-
n_rs = max(1, int(alpha_rs*num_words))
185182

186183
#sr
187-
for _ in range(num_new_per_technique):
188-
a_words = synonym_replacement(words, n_sr)
189-
augmented_sentences.append(' '.join(a_words))
184+
if (alpha_sr > 0):
185+
n_sr = max(1, int(alpha_sr*num_words))
186+
for _ in range(num_new_per_technique):
187+
a_words = synonym_replacement(words, n_sr)
188+
augmented_sentences.append(' '.join(a_words))
190189

191190
#ri
192-
for _ in range(num_new_per_technique):
193-
a_words = random_insertion(words, n_ri)
194-
augmented_sentences.append(' '.join(a_words))
191+
if (alpha_ri > 0):
192+
n_ri = max(1, int(alpha_ri*num_words))
193+
for _ in range(num_new_per_technique):
194+
a_words = random_insertion(words, n_ri)
195+
augmented_sentences.append(' '.join(a_words))
195196

196197
#rs
197-
for _ in range(num_new_per_technique):
198-
a_words = random_swap(words, n_rs)
199-
augmented_sentences.append(' '.join(a_words))
198+
if (alpha_rs > 0):
199+
n_rs = max(1, int(alpha_rs*num_words))
200+
for _ in range(num_new_per_technique):
201+
a_words = random_swap(words, n_rs)
202+
augmented_sentences.append(' '.join(a_words))
200203

201204
#rd
202-
for _ in range(num_new_per_technique):
203-
a_words = random_deletion(words, p_rd)
204-
augmented_sentences.append(' '.join(a_words))
205+
if (p_rd > 0):
206+
for _ in range(num_new_per_technique):
207+
a_words = random_deletion(words, p_rd)
208+
augmented_sentences.append(' '.join(a_words))
205209

206210
augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences]
207211
shuffle(augmented_sentences)

0 commit comments

Comments
 (0)