Andreas Grivas 2015-11-04 18:35:54 +00:00
commit 44fef28d92
5 changed files with 105 additions and 19 deletions

View File

@@ -163,7 +163,7 @@ def run_setup(exts):
'spacy.tests.munge',
'spacy.tests.parser',
'spacy.tests.serialize',
'spacy.tests.span',
'spacy.tests.spans',
'spacy.tests.tagger',
'spacy.tests.tokenizer',
'spacy.tests.tokens',

View File

@@ -1,7 +1,6 @@
from __future__ import unicode_literals
import pytest
@pytest.mark.models
def test_merge_tokens(EN):
tokens = EN(u'Los Angeles start.')
assert len(tokens) == 4
@@ -13,7 +12,6 @@ def test_merge_tokens(EN):
assert tokens[0].head.orth_ == 'start'
@pytest.mark.models
def test_merge_heads(EN):
tokens = EN(u'I found a pilates class near work.')
assert len(tokens) == 8
@@ -32,7 +30,6 @@ def test_issue_54(EN):
text = u'Talks given by women had a slightly higher number of questions asked (3.2$\pm$0.2) than talks given by men (2.6$\pm$0.1).'
tokens = EN(text)
@pytest.mark.models
def test_np_merges(EN):
text = u'displaCy is a parse tool built with Javascript'
tokens = EN(text)
@@ -47,3 +44,27 @@ def test_np_merges(EN):
merged = tokens.merge(start, end, label, lemma, label)
assert merged != None, (start, end, label, lemma)
def test_entity_merge(EN):
tokens = EN(u'Stewart Lee is a stand up comedian who lives in England and loves Joe Pasquale')
assert(len(tokens) == 15)
for ent in tokens.ents:
label, lemma, type_ = (ent.root.tag_, ent.root.lemma_, max(w.ent_type_ for w in ent))
ent.merge(label, lemma, type_)
# check that merging while iterating over the entities is safe
assert(len(tokens) == 13)
def test_sentence_update_after_merge(EN):
tokens = EN(u'Stewart Lee is a stand up comedian. He lives in England and loves Joe Pasquale')
sent1, sent2 = list(tokens.sents)
init_len = len(sent1)
merge_me = tokens[0:2]
merge_me.merge(u'none', u'none', u'none')
assert(len(sent1) == init_len - 1)
def test_subtree_size_check(EN):
tokens = EN(u'Stewart Lee is a stand up comedian who lives in England and loves Joe Pasquale')
sent1 = list(tokens.sents)[0]
init_len = len(list(sent1.root.subtree))
merge_me = tokens[0:2]
merge_me.merge(u'none', u'none', u'none')
assert(len(list(sent1.root.subtree)) == init_len - 1)

View File

@@ -438,11 +438,26 @@ cdef class Doc:
keep_reading = False
yield n_bytes_str + data
# This function is terrible --- need to fix this.
def merge(self, int start_idx, int end_idx, unicode tag, unicode lemma,
unicode ent_type):
"""Merge a multi-word expression into a single token. Currently
experimental; API is likely to change."""
def token_index_start(self, int start_idx):
""" Get index of token in doc that has character index start_idx """
cdef int i
for i in range(self.length):
if self.c[i].idx == start_idx:
return i
return None
def token_index_end(self, int end_idx):
""" Get index+1 of token in doc ending with character index end_idx """
cdef int i
for i in range(self.length):
if (self.c[i].idx + self.c[i].lex.length) == end_idx:
return i + 1
return None
def range_from_indices(self, int start_idx, int end_idx):
""" Get tuple - span of token indices which correspond to
character indices (start_idx, end_idx) if such a span exists"""
cdef int i
cdef int start = -1
cdef int end = -1
@@ -453,10 +468,18 @@ cdef class Doc:
if start == -1:
return None
end = i + 1
break
else:
return (start, end)
return None
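Taken together, these helpers map character offsets back to token indices by scanning the token array and comparing each token's idx (and length) against the requested offset. A pure-Python sketch of the same lookup, with a hypothetical list of (idx, length) pairs standing in for the doc's TokenC array:

def token_index_start(tokens, start_idx):
    # index of the token whose first character sits at start_idx
    for i, (idx, length) in enumerate(tokens):
        if idx == start_idx:
            return i
    return None

def token_index_end(tokens, end_idx):
    # index+1 of the token whose last character sits just before end_idx
    for i, (idx, length) in enumerate(tokens):
        if idx + length == end_idx:
            return i + 1
    return None

def range_from_indices(tokens, start_idx, end_idx):
    # (start, end) token slice for the character span, or None
    start = token_index_start(tokens, start_idx)
    end = token_index_end(tokens, end_idx)
    if start is None or end is None:
        return None
    return (start, end)

# 'Los Angeles start.' -> (idx, length) pairs for its four tokens
tokens = [(0, 3), (4, 7), (12, 5), (17, 1)]
assert range_from_indices(tokens, 0, 11) == (0, 2)   # 'Los Angeles'
assert range_from_indices(tokens, 1, 11) is None     # offset 1 falls mid-token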
# This function is terrible --- need to fix this.
def merge(self, int start_idx, int end_idx, unicode tag, unicode lemma,
unicode ent_type):
"""Merge a multi-word expression into a single token. Currently
experimental; API is likely to change."""
start_end = self.range_from_indices(start_idx, end_idx)
if start_end is None:
return None
start, end = start_end
cdef Span span = self[start:end]
# Get LexemeC for newly merged token
new_orth = ''.join([t.text_with_ws for t in span])
@@ -465,8 +488,6 @@ cdef class Doc:
cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
# House the new merged token where it starts
cdef TokenC* token = &self.c[start]
# Update fields
token.lex = lex
token.spacy = self.c[end-1].spacy
if tag in self.vocab.morphology.tag_map:
self.vocab.morphology.assign_tag(token, tag)
@@ -485,6 +506,10 @@ cdef class Doc:
# bridges over the entity. Here the alignment of the tokens changes.
span_root = span.root.i
token.dep = span.root.dep
# Update token.lex only after the span's root and dep have been recorded,
# since setting token.lex modifies the character offsets in the doc, which
# in turn changes the span.start and span.end properties
token.lex = lex
for i in range(self.length):
self.c[i].head += i
# Set the head of the merged token, and its dep relation, from the Span
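For reference, the merge API above takes character offsets rather than token indices and returns the merged token, or None when the offsets do not line up with token boundaries. A minimal usage sketch, assuming the models-backed EN fixture from the tests (the tag, lemma and entity type arguments are illustrative):

doc = EN(u'Los Angeles start.')                        # 4 tokens
# 'Los Angeles' covers character offsets [0, 11)
merged = doc.merge(0, 11, u'NNP', u'Los Angeles', u'GPE')
if merged is not None:
    assert len(doc) == 3                               # 'Los Angeles', 'start', '.'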

View File

@@ -4,8 +4,10 @@ from .doc cimport Doc
cdef class Span:
cdef readonly Doc doc
cdef public int i
cdef public int start
cdef public int end
cdef public int start_token
cdef public int end_token
cdef public int start_idx
cdef public int end_idx
cdef readonly int label
cdef public _vector

View File

@@ -14,15 +14,19 @@ from ..util import normalize_slice
cdef class Span:
"""A slice from a Doc object."""
"""A slice from a Doc object. Internally keeps character offsets in order
to keep track of changes (merges) in the original Doc. Updates are
made in start and end property."""
def __cinit__(self, Doc tokens, int start, int end, int label=0, vector=None,
vector_norm=None):
if not (0 <= start <= end <= len(tokens)):
raise IndexError
self.doc = tokens
self.start = start
self.end = end
self.start_token = start
self.start_idx = self.doc[start].idx
self.end_token = end
self.end_idx = self.doc[end - 1].idx + len(self.doc[end - 1])
self.label = label
self._vector = vector
self._vector_norm = vector_norm
@@ -76,6 +80,40 @@ cdef class Span:
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
property start:
""" Get start token index of this span from the Doc."""
def __get__(self):
# If spans in the Doc have been merged, start may have shifted.
# Check that start_token is still within the doc's range and that the
# token there still begins at character offset start_idx.
# Checking the range first avoids an IndexError from the second check.
if self.start_token >= len(self.doc) or self.doc[self.start_token].idx != self.start_idx:
# scan the Doc for the token whose character offset equals start_idx
new_start = self.doc.token_index_start(self.start_idx)
if new_start is not None:
self.start_token = new_start
else:
raise IndexError('Something went terribly wrong during a merge. '
'No token found with idx %s' % self.start_idx)
return self.start_token
property end:
""" Get end token index of this span from the Doc."""
def __get__(self):
# If spans in the Doc have been merged, end may have shifted.
# Check that end_token is still within the doc's range and that the
# last token still lines up with character offset end_idx.
# Checking the range first avoids an IndexError from the second check.
if self.end_token >= len(self.doc) or self.doc[self.end_token - 1].idx != self.end_idx:
# scan the Doc for the token that ends at character offset end_idx
new_end = self.doc.token_index_end(self.end_idx)
if new_end is not None:
self.end_token = new_end
else:
raise IndexError('Something went terribly wrong during a merge. '
'No token found with idx %s' % self.end_idx)
return self.end_token
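In short, a Span caches its token boundaries in start_token and end_token but revalidates them against the stored character offsets on every access, re-scanning the Doc when a merge has shifted the token indices. A minimal sketch of the behaviour, assuming the models-backed EN fixture from the tests (the merge arguments are illustrative):

doc = EN(u'Stewart Lee is a stand up comedian. He lives in England and loves Joe Pasquale')
sent1, sent2 = list(doc.sents)
old_start = sent2.start                              # token index of 'He' before the merge
doc[0:2].merge(u'NNP', u'stewart lee', u'PERSON')    # collapse 'Stewart Lee' into one token
# The merge removed a token before sent2, so its cached start_token is stale;
# reading sent2.start re-derives the index from the stored character offset.
assert sent2.start == old_start - 1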
property vector:
def __get__(self):
if self._vector is None: