Commit fad82a8

train_limit and dev_limit, or split
1 parent 2c21d15 commit fad82a8

File tree

3 files changed, +14 -14 lines changed


ml_datasets/loaders/dbpedia.py

Lines changed: 3 additions & 3 deletions
@@ -11,14 +11,14 @@
 
 
 @register_loader("dbpedia")
-def dbpedia(loc=None, *, limit=0):
+def dbpedia(loc=None, *, train_limit=0, dev_limit=0):
     if loc is None:
         loc = get_file("dbpedia_csv", DBPEDIA_ONTOLOGY_URL, untar=True, unzip=True)
     train_loc = Path(loc) / "train.csv"
     test_loc = Path(loc) / "test.csv"
     return (
-        read_dbpedia_ontology(train_loc, limit=limit),
-        read_dbpedia_ontology(test_loc, limit=limit),
+        read_dbpedia_ontology(train_loc, limit=train_limit),
+        read_dbpedia_ontology(test_loc, limit=dev_limit),
     )
 
 
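
With this change the train and test portions can be capped independently rather than sharing a single limit. A minimal usage sketch, assuming the loader is exposed at package level the way ml_datasets.imdb is, and that 0 still means "no limit" as in the defaults above:

    import ml_datasets

    # Cap the training portion at 10,000 rows but keep the full test set.
    # A limit of 0 means "no limit", matching the keyword defaults.
    train_data, dev_data = ml_datasets.dbpedia(train_limit=10_000, dev_limit=0)

Note that dev_limit applies to test.csv: the loader treats the test portion as the dev set.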

ml_datasets/loaders/imdb.py

Lines changed: 2 additions & 2 deletions
@@ -8,12 +8,12 @@
 
 
 @register_loader("imdb")
-def imdb(loc=None, *, limit=0):
+def imdb(loc=None, *, train_limit=0, dev_limit=0):
     if loc is None:
         loc = get_file("aclImdb", IMDB_URL, untar=True, unzip=True)
     train_loc = Path(loc) / "train"
     test_loc = Path(loc) / "test"
-    return read_imdb(train_loc, limit=limit), read_imdb(test_loc, limit=limit)
+    return read_imdb(train_loc, limit=train_limit), read_imdb(test_loc, limit=dev_limit)
 
 
 def read_imdb(data_dir, *, limit=0):
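
The same pattern applies to the IMDB loader. A short sketch, assuming ml_datasets.imdb is the package-level entry point and that each returned item is a (text, annotation) pair:

    import ml_datasets

    # 1,000 training reviews and 500 dev (test) reviews for a quick run;
    # pass 0 (the default) to load a portion in full.
    train_data, dev_data = ml_datasets.imdb(train_limit=1000, dev_limit=500)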

ml_datasets/spacy_readers.py

Lines changed: 9 additions & 9 deletions
@@ -6,12 +6,12 @@
 
 
 def cmu_reader(
-    path: Path = None, *, freq_cutoff: int = 0, limit: int = 0
+    path: Path = None, *, freq_cutoff: int = 0, limit: int = 0, split=0.9
 ) -> Dict[str, Callable[["Language"], Iterable["Example"]]]:
     from spacy.training.example import Example
 
-    # Deduce the categories above threshold by inspecting all training data
-    all_train_data, _ = list(cmu(path, limit=0))
+    # Deduce the categories above threshold by inspecting all data
+    all_train_data, _ = list(cmu(path, limit=0, split=1))
     counted_cats = {}
     for text, cats in all_train_data:
         for cat in cats:
@@ -20,7 +20,7 @@ def cmu_reader(
     unique_labels = [
         l for l in sorted(counted_cats.keys()) if counted_cats[l] >= freq_cutoff
     ]
-    train_data, dev_data = cmu(path, limit=limit, shuffle=False, labels=unique_labels)
+    train_data, dev_data = cmu(path, limit=limit, shuffle=False, labels=unique_labels, split=split)
 
     def read_examples(data, nlp):
         for text, cats in data:
@@ -36,16 +36,16 @@ def read_examples(data, nlp):
 
 
 def dbpedia_reader(
-    path: Path = None, *, limit: int = 0
+    path: Path = None, *, train_limit: int = 0, dev_limit: int = 0
 ) -> Dict[str, Callable[["Language"], Iterable["Example"]]]:
     from spacy.training.example import Example
 
-    all_train_data, _ = dbpedia(path, limit=0)
+    all_train_data, _ = dbpedia(path, train_limit=0, dev_limit=1)
     unique_labels = set()
     for text, gold_label in all_train_data:
         assert isinstance(gold_label, str)
         unique_labels.add(gold_label)
-    train_data, dev_data = dbpedia(path, limit=limit)
+    train_data, dev_data = dbpedia(path, train_limit=train_limit, dev_limit=dev_limit)
 
     def read_examples(data, nlp):
         for text, gold_label in data:
@@ -60,11 +60,11 @@ def read_examples(data, nlp):
 
 
 def imdb_reader(
-    path: Path = None, *, limit: int = 0
+    path: Path = None, *, train_limit: int = 0, dev_limit: int = 0
 ) -> Dict[str, Callable[["Language"], Iterable["Example"]]]:
     from spacy.training.example import Example
 
-    train_data, dev_data = imdb(path, limit=limit)
+    train_data, dev_data = imdb(path, train_limit=train_limit, dev_limit=dev_limit)
     unique_labels = ["pos", "neg"]
 
     def read_examples(data, nlp):
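
Per the signatures above, each reader returns a mapping from corpus name to a callable that takes the nlp object and yields Example objects. A sketch of calling the readers directly with the new keyword arguments; the "train"/"dev" dict keys are an assumption, since the diff does not show the return statement:

    import spacy
    from ml_datasets.spacy_readers import cmu_reader, imdb_reader

    nlp = spacy.blank("en")

    # New keywords from this commit: a train/dev split ratio for the CMU
    # reader, and independent per-portion limits for the IMDB reader.
    cmu_corpora = cmu_reader(freq_cutoff=5, split=0.8)
    imdb_corpora = imdb_reader(train_limit=1000, dev_limit=500)

    # Assumed keys: "train" and "dev" (not visible in this diff).
    train_examples = list(imdb_corpora["train"](nlp))
    dev_examples = list(imdb_corpora["dev"](nlp))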
