mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
set vector of merged entity (#5085)
* merge_entities sets the vector in the vocab for the merged token * add unit test * import unicode_literals * move code to _merge function * only set vector if vocab has non-zero vectors
This commit is contained in:
parent
3440a72ecb
commit
1a2b8fc264
46
spacy/tests/regression/test_issue5082.py
Normal file
46
spacy/tests/regression/test_issue5082.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.pipeline import EntityRuler
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue5082():
|
||||||
|
# Ensure the 'merge_entities' pipeline does something sensible for the vectors of the merged tokens
|
||||||
|
nlp = English()
|
||||||
|
vocab = nlp.vocab
|
||||||
|
array1 = np.asarray([0.1, 0.5, 0.8], dtype=np.float32)
|
||||||
|
array2 = np.asarray([-0.2, -0.6, -0.9], dtype=np.float32)
|
||||||
|
array3 = np.asarray([0.3, -0.1, 0.7], dtype=np.float32)
|
||||||
|
array4 = np.asarray([0.5, 0, 0.3], dtype=np.float32)
|
||||||
|
array34 = np.asarray([0.4, -0.05, 0.5], dtype=np.float32)
|
||||||
|
|
||||||
|
vocab.set_vector("I", array1)
|
||||||
|
vocab.set_vector("like", array2)
|
||||||
|
vocab.set_vector("David", array3)
|
||||||
|
vocab.set_vector("Bowie", array4)
|
||||||
|
|
||||||
|
text = "I like David Bowie"
|
||||||
|
ruler = EntityRuler(nlp)
|
||||||
|
patterns = [
|
||||||
|
{"label": "PERSON", "pattern": [{"LOWER": "david"}, {"LOWER": "bowie"}]}
|
||||||
|
]
|
||||||
|
ruler.add_patterns(patterns)
|
||||||
|
nlp.add_pipe(ruler)
|
||||||
|
|
||||||
|
parsed_vectors_1 = [t.vector for t in nlp(text)]
|
||||||
|
assert len(parsed_vectors_1) == 4
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_1[0], array1)
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_1[1], array2)
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_1[2], array3)
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_1[3], array4)
|
||||||
|
|
||||||
|
merge_ents = nlp.create_pipe("merge_entities")
|
||||||
|
nlp.add_pipe(merge_ents)
|
||||||
|
|
||||||
|
parsed_vectors_2 = [t.vector for t in nlp(text)]
|
||||||
|
assert len(parsed_vectors_2) == 3
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_2[0], array1)
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_2[1], array2)
|
||||||
|
np.testing.assert_array_equal(parsed_vectors_2[2], array34)
|
|
@ -213,6 +213,10 @@ def _merge(Doc doc, merges):
|
||||||
new_orth = ''.join([t.text_with_ws for t in spans[token_index]])
|
new_orth = ''.join([t.text_with_ws for t in spans[token_index]])
|
||||||
if spans[token_index][-1].whitespace_:
|
if spans[token_index][-1].whitespace_:
|
||||||
new_orth = new_orth[:-len(spans[token_index][-1].whitespace_)]
|
new_orth = new_orth[:-len(spans[token_index][-1].whitespace_)]
|
||||||
|
# add the vector of the (merged) entity to the vocab
|
||||||
|
if not doc.vocab.get_vector(new_orth).any():
|
||||||
|
if doc.vocab.vectors_length > 0:
|
||||||
|
doc.vocab.set_vector(new_orth, span.vector)
|
||||||
token = tokens[token_index]
|
token = tokens[token_index]
|
||||||
lex = doc.vocab.get(doc.mem, new_orth)
|
lex = doc.vocab.get(doc.mem, new_orth)
|
||||||
token.lex = lex
|
token.lex = lex
|
||||||
|
|
Loading…
Reference in New Issue
Block a user