Skip to content

Commit 741e260

Browse files
adamlerer authored and soumith committed
Initial checkin
1 parent 2ff9485 commit 741e260

File tree

13 files changed

+1090
-0
lines changed

13 files changed

+1090
-0
lines changed

OpenNMT/LICENSE.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

OpenNMT/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
# OpenNMT: Open-Source Neural Machine Translation

This is a [Pytorch](https://github.com/pytorch/pytorch)
port of [OpenNMT](https://github.com/OpenNMT/OpenNMT),
an open-source (MIT) neural machine translation system.

<center style="padding: 40px"><img width="70%" src="http://opennmt.github.io/simple-attn.png" /></center>

## Quickstart

OpenNMT consists of three commands:

1) Preprocess the data.

```python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo```

2) Train the model.

```python train.py -data data/demo-train.pt -save_model model -cuda```

3) Translate sentences.

TODO

OpenNMT/onmt/Constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
# Reserved vocabulary indices shared across the project.
PAD = 0  # padding token
UNK = 1  # unknown-word token
BOS = 2  # beginning-of-sentence token
EOS = 3  # end-of-sentence token

# Surface forms corresponding to the reserved indices above.
PAD_WORD = '<blank>'
UNK_WORD = '<unk>'
BOS_WORD = '<s>'
EOS_WORD = '</s>'

OpenNMT/onmt/Dataset.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import onmt
2+
from torch.autograd import Variable
3+
4+
5+
class Dataset(object):
    """Indexable collection of (source, target) minibatches.

    Batches are padded with ``onmt.Constants.PAD`` (source sequences
    right-aligned, target sequences left-aligned) and returned transposed
    to (seq_len x batch) layout as autograd ``Variable``s.
    """

    # FIXME: randomize
    def __init__(self, srcData, tgtData, batchSize, cuda):
        # srcData / tgtData: dicts with a 'words' list of 1-D LongTensors,
        # one tensor per sentence.
        self.src = srcData['words']
        self.tgt = tgtData['words']
        self.cuda = cuda
        # FIXME
        # self.srcFeatures = srcData.features
        # self.tgtFeatures = tgtData.features
        assert len(self.src) == len(self.tgt), \
            "source and target corpora must have the same number of sentences"
        self.batchSize = batchSize
        # NOTE: a trailing partial batch is silently dropped.
        self.numBatches = len(self.src) // batchSize

    def _batchify(self, data, align_right=False):
        """Pad a list of 1-D tensors into one (len(data) x max_length) Variable."""
        max_length = max(x.size(0) for x in data)
        out = data[0].new(len(data), max_length).fill_(onmt.Constants.PAD)
        for i in range(len(data)):
            data_length = data[i].size(0)
            # Right-align by shifting the copy offset to the end of the row.
            offset = max_length - data_length if align_right else 0
            out[i].narrow(0, offset, data_length).copy_(data[i])
        return Variable(out)

    def __getitem__(self, index):
        """Return the (srcBatch, tgtBatch) pair for batch number `index`."""
        # Fixed: the old message read '%d > %d' although the assertion
        # fires whenever index >= numBatches (e.g. it printed "5 > 5").
        assert index < self.numBatches, \
            "batch index %d out of range (%d batches)" % (index, self.numBatches)
        srcBatch = self._batchify(
            self.src[index*self.batchSize:(index+1)*self.batchSize],
            align_right=True)
        tgtBatch = self._batchify(
            self.tgt[index*self.batchSize:(index+1)*self.batchSize])

        if self.cuda:
            srcBatch = srcBatch.cuda()
            tgtBatch = tgtBatch.cuda()

        # FIXME
        # Transpose to (seq_len x batch), the layout the models consume.
        srcBatch = srcBatch.t().contiguous()
        tgtBatch = tgtBatch.t().contiguous()

        return srcBatch, tgtBatch

    def __len__(self):
        return self.numBatches

OpenNMT/onmt/Dict.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import torch
2+
3+
4+
class Dict(object):
    """Bidirectional label <-> index vocabulary with frequency counts.

    Can be initialized empty, from a vocabulary file (``data`` is a str
    path of "label idx" lines), or from a list of special labels.
    """

    def __init__(self, data=None):
        self.idxToLabel = {}
        self.labelToIdx = {}
        self.frequencies = {}

        # Special entries will not be pruned.
        self.special = []

        if data is not None:
            if type(data) == str:
                self.loadFile(data)
            else:
                self.addSpecials(data)

    def size(self):
        """Return the number of entries in the dictionary."""
        return len(self.idxToLabel)

    def loadFile(self, filename):
        """Load entries from a file of "label idx" lines."""
        # Fixed: use a context manager so the handle is always closed
        # (the original left the file open).
        with open(filename) as f:
            for line in f:
                fields = line.split()
                label = fields[0]
                idx = int(fields[1])
                self.add(label, idx)

    def writeFile(self, filename):
        """Write entries to `filename`, one "label idx" pair per line."""
        # Fixed: dropped the redundant file.close() that followed the
        # with-block (the context manager already closes the file) and
        # renamed the handle so it no longer shadows the `file` builtin.
        with open(filename, 'w') as f:
            for i in range(self.size()):
                label = self.idxToLabel[i]
                f.write('%s %d\n' % (label, i))

    def lookup(self, key, default=None):
        """Return the index of label `key`, or `default` if absent."""
        try:
            return self.labelToIdx[key]
        except KeyError:
            return default

    def getLabel(self, idx, default=None):
        """Return the label stored at `idx`, or `default` if absent."""
        try:
            return self.idxToLabel[idx]
        except KeyError:
            return default

    def addSpecial(self, label, idx=None):
        """Mark this `label` and `idx` as special (i.e. will not be pruned)."""
        idx = self.add(label, idx)
        self.special += [idx]

    def addSpecials(self, labels):
        """Mark all labels in `labels` as specials (i.e. will not be pruned)."""
        for label in labels:
            self.addSpecial(label)

    def add(self, label, idx=None):
        """Add `label` to the dictionary; use `idx` as its index if given.

        Increments the label's frequency count and returns its index.
        """
        if idx is not None:
            self.idxToLabel[idx] = label
            self.labelToIdx[label] = idx
        else:
            if label in self.labelToIdx:
                idx = self.labelToIdx[label]
            else:
                idx = len(self.idxToLabel)
                self.idxToLabel[idx] = label
                self.labelToIdx[label] = idx

        if idx not in self.frequencies:
            self.frequencies[idx] = 1
        else:
            self.frequencies[idx] += 1

        return idx

    def prune(self, size):
        """Return a new dictionary with the `size` most frequent entries."""
        if size >= self.size():
            return self

        # Only keep the `size` most frequent entries.
        # Assumes indices form a contiguous 0..n-1 range (true when entries
        # were added without explicit idx) — TODO confirm for loaded files.
        freq = torch.Tensor(
            [self.frequencies[i] for i in range(len(self.frequencies))])
        _, idx = torch.sort(freq, 0, True)

        newDict = Dict()

        # Add special entries in all cases.
        for i in self.special:
            newDict.addSpecial(self.idxToLabel[i])

        # Fixed: iterating a LongTensor slice yields tensor elements, which
        # are not valid keys for the int-keyed idxToLabel dict; convert the
        # kept indices to plain Python ints first.
        for i in idx[:size].tolist():
            newDict.add(self.idxToLabel[i])

        return newDict

    def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
        """Convert `labels` to a LongTensor of indices, mapping unknown
        labels to the index of `unkWord`. Optionally insert `bosWord` at
        the beginning and `eosWord` at the end.
        """
        vec = []

        if bosWord is not None:
            vec += [self.lookup(bosWord)]

        unk = self.lookup(unkWord)
        vec += [self.lookup(label, default=unk) for label in labels]

        if eosWord is not None:
            vec += [self.lookup(eosWord)]

        return torch.LongTensor(vec)

    def convertToLabels(self, idx, stop):
        """Convert indices `idx` to labels; if index `stop` is reached,
        convert it and return early.
        """
        labels = []

        for i in idx:
            labels += [self.getLabel(i)]
            if i == stop:
                break

        return labels

0 commit comments

Comments
 (0)