Improve API for doc.merge() and span.merge(), to use keyword arguments.

This commit is contained in:
Matthew Honnibal 2016-10-17 14:02:13 +02:00
parent fbb7f3f15c
commit b67697a97b
2 changed files with 21 additions and 5 deletions

View File

@ -593,9 +593,22 @@ cdef class Doc:
keep_reading = False keep_reading = False
yield n_bytes_str + data yield n_bytes_str + data
def merge(self, int start_idx, int end_idx, unicode tag, unicode lemma, def merge(self, int start_idx, int end_idx, *args, **attributes):
unicode ent_type):
"""Merge a multi-word expression into a single token.""" """Merge a multi-word expression into a single token."""
cdef unicode tag, lemma, ent_type
if len(args) == 3:
# TODO: Warn deprecation
tag, lemma, ent_type = args
attributes[TAG] = self.strings[tag]
attributes[LEMMA] = self.strings[lemma]
attributes[ENT_TYPE] = self.strings[ent_type]
elif args:
raise ValueError(
"Doc.merge received %d non-keyword arguments. "
"Expected either 3 arguments (deprecated), or 0 (use keyword arguments). "
"Arguments supplied:\n%s\n"
"Keyword arguments:%s\n" % (len(args), repr(args), repr(attributes)))
cdef int start = token_by_start(self.c, self.length, start_idx) cdef int start = token_by_start(self.c, self.length, start_idx)
if start == -1: if start == -1:
return None return None
@ -604,8 +617,11 @@ cdef class Doc:
return None return None
# Currently we have the token index, we want the range-end index # Currently we have the token index, we want the range-end index
end += 1 end += 1
cdef Span span = self[start:end] cdef Span span = self[start:end]
tag = self.strings[attributes.get(TAG, span.root.tag)]
lemma = self.strings[attributes.get(LEMMA, span.root.lemma)]
ent_type = self.strings[attributes.get(ENT_TYPE, span.root.ent_type)]
# Get LexemeC for newly merged token # Get LexemeC for newly merged token
new_orth = ''.join([t.text_with_ws for t in span]) new_orth = ''.join([t.text_with_ws for t in span])
if span[-1].whitespace_: if span[-1].whitespace_:

View File

@ -77,8 +77,8 @@ cdef class Span:
for i in range(self.start, self.end): for i in range(self.start, self.end):
yield self.doc[i] yield self.doc[i]
def merge(self, unicode tag, unicode lemma, unicode ent_type): def merge(self, *args, **attributes):
self.doc.merge(self.start_char, self.end_char, tag, lemma, ent_type) self.doc.merge(self.start_char, self.end_char, *args, **attributes)
def similarity(self, other): def similarity(self, other):
if 'similarity' in self.doc.getters_for_spans: if 'similarity' in self.doc.getters_for_spans: