mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
The doc.retokenize() context manager wasn't resizing doc.tensor, leading to a mismatch between the number of tokens in the doc and the number of rows in the tensor. We fix this by deleting rows from the tensor. Merged spans are represented by the vector of their last token. * Add test for resizing doc.tensor when merging * Add test for resizing doc.tensor when merging. Closes #1963 * Update get_lca_matrix test for develop * Fix retokenize if tensor unset
This commit is contained in:
parent
3d64eb4a74
commit
72e4d3782a
|
@ -247,6 +247,16 @@ def test_issue1945():
|
||||||
assert matches[1][1:] == (1, 3)
|
assert matches[1][1:] == (1, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue1963(en_tokenizer):
|
||||||
|
"""Test that doc.merge() resizes doc.tensor"""
|
||||||
|
doc = en_tokenizer('a b c d')
|
||||||
|
doc.tensor = numpy.ones((len(doc), 128), dtype='f')
|
||||||
|
with doc.retokenize() as retokenizer:
|
||||||
|
retokenizer.merge(doc[0:2])
|
||||||
|
assert len(doc) == 3
|
||||||
|
assert doc.tensor.shape == (3, 128)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("label", ["U-JOB-NAME"])
|
@pytest.mark.parametrize("label", ["U-JOB-NAME"])
|
||||||
def test_issue1967(label):
|
def test_issue1967(label):
|
||||||
ner = EntityRecognizer(Vocab())
|
ner = EntityRecognizer(Vocab())
|
||||||
|
|
|
@ -7,7 +7,9 @@ from __future__ import unicode_literals
|
||||||
from libc.string cimport memcpy, memset
|
from libc.string cimport memcpy, memset
|
||||||
from libc.stdlib cimport malloc, free
|
from libc.stdlib cimport malloc, free
|
||||||
|
|
||||||
|
import numpy
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
from .doc cimport Doc, set_children_from_heads, token_by_start, token_by_end
|
from .doc cimport Doc, set_children_from_heads, token_by_start, token_by_end
|
||||||
from .span cimport Span
|
from .span cimport Span
|
||||||
|
@ -83,6 +85,11 @@ def _merge(Doc doc, int start, int end, attributes):
|
||||||
cdef Span span = doc[start:end]
|
cdef Span span = doc[start:end]
|
||||||
cdef int start_char = span.start_char
|
cdef int start_char = span.start_char
|
||||||
cdef int end_char = span.end_char
|
cdef int end_char = span.end_char
|
||||||
|
# Resize the doc.tensor, if it's set. Let the last row for each token stand
|
||||||
|
# for the merged region. To do this, we create a boolean array indicating
|
||||||
|
# whether the row is to be deleted, then use numpy.delete
|
||||||
|
if doc.tensor is not None and doc.tensor.size != 0:
|
||||||
|
doc.tensor = _resize_tensor(doc.tensor, [(start, end)])
|
||||||
# Get LexemeC for newly merged token
|
# Get LexemeC for newly merged token
|
||||||
new_orth = ''.join([t.text_with_ws for t in span])
|
new_orth = ''.join([t.text_with_ws for t in span])
|
||||||
if span[-1].whitespace_:
|
if span[-1].whitespace_:
|
||||||
|
@ -182,7 +189,12 @@ def _bulk_merge(Doc doc, merges):
|
||||||
else:
|
else:
|
||||||
Token.set_struct_attr(token, attr_name, attr_value)
|
Token.set_struct_attr(token, attr_name, attr_value)
|
||||||
|
|
||||||
|
# Resize the doc.tensor, if it's set. Let the last row for each token stand
|
||||||
|
# for the merged region. To do this, we create a boolean array indicating
|
||||||
|
# whether the row is to be deleted, then use numpy.delete
|
||||||
|
if doc.tensor is not None and doc.tensor.size != 0:
|
||||||
|
doc.tensor = _resize_tensor(doc.tensor,
|
||||||
|
[(m[1][0].start, m[1][0].end) for m in merges])
|
||||||
# Memorize span roots and sets dependencies of the newly merged
|
# Memorize span roots and sets dependencies of the newly merged
|
||||||
# tokens to the dependencies of their roots.
|
# tokens to the dependencies of their roots.
|
||||||
span_roots = []
|
span_roots = []
|
||||||
|
@ -276,6 +288,14 @@ def _bulk_merge(Doc doc, merges):
|
||||||
else:
|
else:
|
||||||
# If they're not the same entity type, let them be two entities
|
# If they're not the same entity type, let them be two entities
|
||||||
doc.c[token_after_span_position].ent_iob = 3
|
doc.c[token_after_span_position].ent_iob = 3
|
||||||
|
|
||||||
# Return the merged Python object
|
# Return the merged Python object
|
||||||
return doc[spans[0].start]
|
return doc[spans[0].start]
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_tensor(tensor, ranges):
|
||||||
|
delete = []
|
||||||
|
for start, end in ranges:
|
||||||
|
for i in range(start, end-1):
|
||||||
|
delete.append(i)
|
||||||
|
xp = get_array_module(tensor)
|
||||||
|
return xp.delete(tensor, delete, axis=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user