mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-29 17:33:10 +03:00
Merge pull request #625 from pspiegelhalter/master
PR to fix Issue #623
This commit is contained in:
commit
2ee66117ba
|
@ -7,6 +7,7 @@ from pathlib import Path
|
||||||
from spacy_hook import get_embeddings, get_word_ids
|
from spacy_hook import get_embeddings, get_word_ids
|
||||||
from spacy_hook import create_similarity_pipeline
|
from spacy_hook import create_similarity_pipeline
|
||||||
|
|
||||||
|
from keras_decomposable_attention import build_model
|
||||||
|
|
||||||
def train(model_dir, train_loc, dev_loc, shape, settings):
|
def train(model_dir, train_loc, dev_loc, shape, settings):
|
||||||
print("Loading spaCy")
|
print("Loading spaCy")
|
||||||
|
|
|
@ -101,10 +101,11 @@ class _Attention(object):
|
||||||
self.model = TimeDistributed(self.model)
|
self.model = TimeDistributed(self.model)
|
||||||
|
|
||||||
def __call__(self, sent1, sent2):
|
def __call__(self, sent1, sent2):
|
||||||
def _outer((A, B)):
|
def _outer(AB):
|
||||||
att_ji = T.batched_dot(B, A.dimshuffle((0, 2, 1)))
|
att_ji = T.batched_dot(AB[1], AB[0].dimshuffle((0, 2, 1)))
|
||||||
return att_ji.dimshuffle((0, 2, 1))
|
return att_ji.dimshuffle((0, 2, 1))
|
||||||
|
|
||||||
|
|
||||||
return merge(
|
return merge(
|
||||||
[self.model(sent1), self.model(sent2)],
|
[self.model(sent1), self.model(sent2)],
|
||||||
mode=_outer,
|
mode=_outer,
|
||||||
|
@ -117,7 +118,9 @@ class _SoftAlignment(object):
|
||||||
self.nr_hidden = nr_hidden
|
self.nr_hidden = nr_hidden
|
||||||
|
|
||||||
def __call__(self, sentence, attention, transpose=False):
|
def __call__(self, sentence, attention, transpose=False):
|
||||||
def _normalize_attention((att, mat)):
|
def _normalize_attention(attmat):
|
||||||
|
att = attmat[0]
|
||||||
|
mat = attmat[1]
|
||||||
if transpose:
|
if transpose:
|
||||||
att = att.dimshuffle((0, 2, 1))
|
att = att.dimshuffle((0, 2, 1))
|
||||||
# 3d softmax
|
# 3d softmax
|
||||||
|
|
Loading…
Reference in New Issue
Block a user