Fix regression test for #615 and remove unnecessary imports

This commit is contained in:
Ines Montani 2017-01-12 16:49:40 +01:00
parent aeb747e10c
commit 51ef75f629

View File

@ -1,35 +1,34 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import spacy from ...matcher import Matcher
from spacy.attrs import ORTH from ...attrs import ORTH
from ..util import get_doc
def merge_phrases(matcher, doc, i, matches): def test_issue615(en_tokenizer):
''' def merge_phrases(matcher, doc, i, matches):
Merge a phrase. We have to be careful here because we'll change the token indices. """Merge a phrase. We have to be careful here because we'll change the
To avoid problems, merge all the phrases once we're called on the last match. token indices. To avoid problems, merge all the phrases once we're called
''' on the last match."""
if i != len(matches)-1:
return None
# Get Span objects
spans = [(ent_id, label, doc[start : end]) for ent_id, label, start, end in matches]
for ent_id, label, span in spans:
span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
def test_entity_ID_assignment(): if i != len(matches)-1:
nlp = spacy.en.English() return None
text = """The golf club is broken""" # Get Span objects
doc = nlp(text) spans = [(ent_id, label, doc[start : end]) for ent_id, label, start, end in matches]
for ent_id, label, span in spans:
span.merge('NNP' if label else span.root.tag_, span.text, doc.vocab.strings[label])
golf_pattern = [ text = "The golf club is broken"
{ ORTH: "golf"}, pattern = [{ ORTH: "golf"}, { ORTH: "club"}]
{ ORTH: "club"} label = "Sport_Equipment"
]
matcher = spacy.matcher.Matcher(nlp.vocab) tokens = en_tokenizer(text)
matcher.add_entity('Sport_Equipment', on_match = merge_phrases) doc = get_doc(tokens.vocab, [t.text for t in tokens])
matcher.add_pattern("Sport_Equipment", golf_pattern, label = 'Sport_Equipment')
matcher = Matcher(doc.vocab)
matcher.add_entity(label, on_match=merge_phrases)
matcher.add_pattern(label, pattern, label=label)
match = matcher(doc) match = matcher(doc)
entities = list(doc.ents) entities = list(doc.ents)