mirror of https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00

Fix conflict

This commit is contained in:
parent 12a7b05360
commit 389e8b700e
@@ -22,22 +22,13 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
     print("Compiling network")
     model = build_model(get_embeddings(nlp.vocab), shape, settings)
     print("Processing texts...")
-    train_X1 = get_word_ids(list(nlp.pipe(train_texts1, n_threads=10, batch_size=10000)),
-                            max_length=shape[0],
-                            tree_truncate=settings['tree_truncate'])
-    train_X2 = get_word_ids(list(nlp.pipe(train_texts2, n_threads=10, batch_size=10000)),
-                            max_length=shape[0],
-                            tree_truncate=settings['tree_truncate'])
-    dev_X1 = get_word_ids(list(nlp.pipe(dev_texts1, n_threads=10, batch_size=10000)),
-                          max_length=shape[0],
-                          tree_truncate=settings['tree_truncate'])
-    dev_X2 = get_word_ids(list(nlp.pipe(dev_texts2, n_threads=10, batch_size=10000)),
-                          max_length=shape[0],
-                          tree_truncate=settings['tree_truncate'])
-
-    print(train_X1.shape, train_X2.shape)
-    print(dev_X1.shape, dev_X2.shape)
-    print(train_labels.shape, dev_labels.shape)
+    Xs = []
+    for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
+        Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
+                               max_length=shape[0],
+                               rnn_encode=settings['gru_encode'],
+                               tree_truncate=settings['tree_truncate']))
+    train_X1, train_X2, dev_X1, dev_X2 = Xs
     print(settings)
     model.fit(
         [train_X1, train_X2],
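Read as plain Python, the resolved side of this hunk collapses four near-duplicate get_word_ids() calls into one loop and threads the new rnn_encode flag through. A sketch of the merged block (names such as nlp, shape, settings and the four text lists come from the surrounding train() function; the comments are added here):

# Resolved preprocessing block from the hunk above.
Xs = []
for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
    # nlp.pipe() streams Docs in batches; get_word_ids() turns each batch into
    # a fixed-width int32 matrix of word ids (width shape[0]).
    Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
                           max_length=shape[0],
                           rnn_encode=settings['gru_encode'],
                           tree_truncate=settings['tree_truncate']))
train_X1, train_X2, dev_X1, dev_X2 = Xs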
@@ -103,7 +94,7 @@ def read_snli(path):
     dropout=("Dropout level", "option", "d", float),
     learn_rate=("Learning rate", "option", "e", float),
     batch_size=("Batch size for neural network training", "option", "b", float),
-    nr_epoch=("Number of training epochs", "option", "i", float),
+    nr_epoch=("Number of training epochs", "option", "i", int),
     tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
     gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
 )
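The only substantive change here is nr_epoch's converter: float becomes int. In plac, an annotation tuple is (help text, kind, abbreviation, type), so the fourth element decides how the command-line string is parsed, and the epoch count ultimately feeds integer-only code such as range(). A minimal, hypothetical standalone sketch (not part of the example):

import plac

@plac.annotations(
    nr_epoch=("Number of training epochs", "option", "i", int),  # int, not float
)
def main(nr_epoch=5):
    # A float here (e.g. 5.0 parsed from the CLI) would raise TypeError in range().
    for epoch in range(nr_epoch):
        print("epoch", epoch)

if __name__ == '__main__':
    plac.call(main)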
@@ -107,8 +107,6 @@ class _Attention(object):
         def _outer(AB):
             att_ji = K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1)))
             return K.permute_dimensions(att_ji,(0, 2, 1))
-
-
         return merge(
             [self.model(sent1), self.model(sent2)],
             mode=_outer,
@@ -153,6 +151,7 @@ class _Comparison(object):
     def __call__(self, sent, align, **kwargs):
         result = self.model(merge([sent, align], mode='concat')) # Shape: (i, n)
         result = _GlobalSumPooling1D()(result, mask=self.words)
+        result = BatchNormalization()(result)
         return result


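The added line batch-normalizes the summed comparison vector before it reaches the classifier; since _GlobalSumPooling1D sums over a variable number of words, the scale of that vector varies with sentence length, and normalization plausibly steadies training. A minimal sketch of the pattern, assuming the Keras 1.x functional API this example uses (not the example's actual model):

from keras.layers import Input, Dense
from keras.layers.normalization import BatchNormalization
from keras.models import Model

pooled = Input(shape=(200,))           # stand-in for the summed comparison vector
normed = BatchNormalization()(pooled)  # rescale features before the classifier
preds = Dense(3, activation='softmax')(normed)
model = Model(input=pooled, output=preds)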
@@ -40,16 +40,19 @@ def get_embeddings(vocab):
     return vectors


-def get_word_ids(docs, tree_truncate=False, max_length=100):
+def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100):
     Xs = numpy.zeros((len(docs), max_length), dtype='int32')
     for i, doc in enumerate(docs):
-        j = 0
-        queue = [sent.root for sent in doc.sents]
+        if tree_truncate:
+            queue = [sent.root for sent in doc.sents]
+        else:
+            queue = list(doc)
         words = []
         while len(words) <= max_length and queue:
             word = queue.pop(0)
-            if word.has_vector and not word.is_punct and not word.is_space:
+            if rnn_encode or (word.has_vector and not word.is_punct and not word.is_space):
                 words.append(word)
-            queue.extend(list(word.lefts))
-            queue.extend(list(word.rights))
+            if tree_truncate:
+                queue.extend(list(word.lefts))
+                queue.extend(list(word.rights))
         words.sort()
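For readability, here is the resolved get_word_ids() from this hunk read as plain code, with explanatory comments added. Only the lines covered by the hunk are shown; the tail of the loop body (which writes the ids for `words` into row i of Xs) lies outside the diff:

import numpy

def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100):
    Xs = numpy.zeros((len(docs), max_length), dtype='int32')
    for i, doc in enumerate(docs):
        if tree_truncate:
            # start a breadth-first walk from each sentence root, so truncation
            # keeps the syntactically central words
            queue = [sent.root for sent in doc.sents]
        else:
            queue = list(doc)
        words = []
        while len(words) <= max_length and queue:
            word = queue.pop(0)
            # with rnn_encode, keep every token; otherwise keep only tokens that
            # have a vector and are neither punctuation nor whitespace
            if rnn_encode or (word.has_vector and not word.is_punct and not word.is_space):
                words.append(word)
            if tree_truncate:
                queue.extend(list(word.lefts))
                queue.extend(list(word.rights))
        words.sort()
        # ... remainder of the loop body (filling Xs[i] from `words`) is
        # unchanged and outside this hunk
    return Xs  # assumed: the id matrix is what the callers in train() consume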