mirror of https://github.com/explosion/spaCy.git
Fix conflict
This commit is contained in:
parent 12a7b05360
commit 389e8b700e
@@ -22,22 +22,13 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
     print("Compiling network")
     model = build_model(get_embeddings(nlp.vocab), shape, settings)
     print("Processing texts...")
-    train_X1 = get_word_ids(list(nlp.pipe(train_texts1, n_threads=10, batch_size=10000)),
-                            max_length=shape[0],
-                            tree_truncate=settings['tree_truncate'])
-    train_X2 = get_word_ids(list(nlp.pipe(train_texts2, n_threads=10, batch_size=10000)),
-                            max_length=shape[0],
-                            tree_truncate=settings['tree_truncate'])
-    dev_X1 = get_word_ids(list(nlp.pipe(dev_texts1, n_threads=10, batch_size=10000)),
-                          max_length=shape[0],
-                          tree_truncate=settings['tree_truncate'])
-    dev_X2 = get_word_ids(list(nlp.pipe(dev_texts2, n_threads=10, batch_size=10000)),
-                          max_length=shape[0],
-                          tree_truncate=settings['tree_truncate'])
-
-    print(train_X1.shape, train_X2.shape)
-    print(dev_X1.shape, dev_X2.shape)
-    print(train_labels.shape, dev_labels.shape)
+    Xs = []
+    for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
+        Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
+                  max_length=shape[0],
+                  rnn_encode=settings['gru_encode'],
+                  tree_truncate=settings['tree_truncate']))
+    train_X1, train_X2, dev_X1, dev_X2 = Xs
     print(settings)
     model.fit(
         [train_X1, train_X2],
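Note on the hunk above: the conflict is resolved in favour of the batched version, which collapses four near-identical get_word_ids calls into one loop and unpacks the results in a fixed order. A minimal runnable sketch of that pattern; encode_texts is a hypothetical stub standing in for the real get_word_ids/nlp.pipe pipeline:

def encode_texts(texts, max_length=100):
    # Stand-in for get_word_ids(list(nlp.pipe(texts, ...)), ...).
    return [text.split()[:max_length] for text in texts]

train_texts1 = ["a dog runs fast"]
train_texts2 = ["an animal is moving"]
dev_texts1 = ["it is raining"]
dev_texts2 = ["the weather is wet"]

# One loop instead of four copy-pasted calls; unpack in a fixed order.
Xs = []
for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
    Xs.append(encode_texts(texts))
train_X1, train_X2, dev_X1, dev_X2 = Xs
print(train_X1, dev_X2)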
@@ -103,7 +94,7 @@ def read_snli(path):
     dropout=("Dropout level", "option", "d", float),
     learn_rate=("Learning rate", "option", "e", float),
     batch_size=("Batch size for neural network training", "option", "b", float),
-    nr_epoch=("Number of training epochs", "option", "i", float),
+    nr_epoch=("Number of training epochs", "option", "i", int),
     tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
     gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
 )
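Note on the hunk above: in a plac annotation the fourth field is the callable used to convert the raw command-line string, so an epoch count parsed with float arrives as e.g. 5.0 and breaks integer uses such as range(nr_epoch). A minimal runnable sketch of the corrected annotation, assuming the script is driven through plac.call as the spaCy examples are (main below is a hypothetical stand-in):

import plac

@plac.annotations(
    nr_epoch=("Number of training epochs", "option", "i", int),
)
def main(nr_epoch=5):
    for epoch in range(nr_epoch):  # needs an int; a float here raises TypeError
        print("epoch", epoch)

if __name__ == '__main__':
    plac.call(main)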
@@ -107,8 +107,6 @@ class _Attention(object):
         def _outer(AB):
             att_ji = K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1)))
             return K.permute_dimensions(att_ji,(0, 2, 1))
-
-
         return merge(
             [self.model(sent1), self.model(sent2)],
             mode=_outer,
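Note on the hunk above: _outer builds the raw attention matrix of the decomposable attention model (Parikh et al., 2016) that this example implements. K.batch_dot multiplies the two projected sentence matrices per batch item, and permute_dimensions transposes the result so entry (i, j) scores word i of sentence one against word j of sentence two. A numpy sketch of the same computation; the shapes are illustrative assumptions:

import numpy

A = numpy.random.rand(2, 4, 5)   # (batch, len1, dim): projected sentence 1
B = numpy.random.rand(2, 3, 5)   # (batch, len2, dim): projected sentence 2

att_ji = B @ A.transpose(0, 2, 1)   # what K.batch_dot builds: (batch, len2, len1)
att_ij = att_ji.transpose(0, 2, 1)  # permute back to (batch, len1, len2)
print(att_ij.shape)                 # (2, 4, 3)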
@@ -153,6 +151,7 @@ class _Comparison(object):
     def __call__(self, sent, align, **kwargs):
         result = self.model(merge([sent, align], mode='concat')) # Shape: (i, n)
         result = _GlobalSumPooling1D()(result, mask=self.words)
+        result = BatchNormalization()(result)
         return result
 
 
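Note on the hunk above: the added line batch-normalizes the sum-pooled comparison vector before it reaches the downstream classifier layers. Roughly, at training time each feature is standardized across the batch; the Keras layer additionally learns a per-feature scale and shift. A numpy sketch of that normalization, with illustrative shapes:

import numpy

pooled = numpy.random.rand(8, 200)  # (batch, features) after _GlobalSumPooling1D
mean = pooled.mean(axis=0)
var = pooled.var(axis=0)
normed = (pooled - mean) / numpy.sqrt(var + 1e-5)  # epsilon for stability
print(normed.mean(axis=0)[:3].round(6))  # ~0 per feature
print(normed.std(axis=0)[:3].round(3))   # ~1 per feature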
@@ -40,16 +40,19 @@ def get_embeddings(vocab):
     return vectors
 
 
-def get_word_ids(docs, tree_truncate=False, max_length=100):
+def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100):
     Xs = numpy.zeros((len(docs), max_length), dtype='int32')
     for i, doc in enumerate(docs):
-        j = 0
-        queue = [sent.root for sent in doc.sents]
+        if tree_truncate:
+            queue = [sent.root for sent in doc.sents]
+        else:
+            queue = list(doc)
         words = []
         while len(words) <= max_length and queue:
             word = queue.pop(0)
-            if word.has_vector and not word.is_punct and not word.is_space:
+            if rnn_encode or (word.has_vector and not word.is_punct and not word.is_space):
                 words.append(word)
+            if tree_truncate:
                 queue.extend(list(word.lefts))
                 queue.extend(list(word.rights))
         words.sort()
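Note on the hunk above: after the resolution, tree_truncate controls how the token queue is built (dependency-tree roots plus their children versus plain document order), and rnn_encode keeps every token so a GRU encoder sees the full, gap-free sequence. A pure-Python sketch of the filtering rule; Tok is a hypothetical stub in place of real spaCy tokens:

from collections import namedtuple

Tok = namedtuple('Tok', 'text has_vector is_punct is_space')
doc = [Tok('good', True, False, False),
       Tok(',', False, True, False),
       Tok('movie', True, False, False)]

def keep(word, rnn_encode):
    # Mirrors the resolved condition in get_word_ids.
    return rnn_encode or (word.has_vector and not word.is_punct and not word.is_space)

print([w.text for w in doc if keep(w, rnn_encode=False)])  # ['good', 'movie']
print([w.text for w in doc if keep(w, rnn_encode=True)])   # ['good', ',', 'movie']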