Merge pull request #1400 from explosion/feature/sentence-parsing
💫 Force parser to respect preset sentence boundaries
Commit 689349e32f
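This change makes the transition-based parser honour sentence boundaries that were set on the Doc before parsing. A minimal sketch of the intended behaviour, mirroring the new tests added below (Doc, NeuralDependencyParser and the sent_start flag all appear in this diff; a trained `parser` object is assumed, as built in the test fixture):

# Hedged sketch of the behaviour this PR enforces; `parser` is assumed to be a
# trained NeuralDependencyParser, as constructed in the new test fixture below.
doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
doc[2].sent_start = True          # preset boundary: token 2 must open a sentence
doc = parser(doc)
assert len(list(doc.sents)) >= 2  # the parser keeps the preset break
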
@@ -61,13 +61,13 @@ cdef struct TokenC:
     attr_t sense
     int head
     attr_t dep
-    bint sent_start
 
     uint32_t l_kids
     uint32_t r_kids
     uint32_t l_edge
     uint32_t r_edge
 
+    int sent_start
     int ent_iob
     attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
     hash_t ent_id
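The struct field changes from a two-valued bint to a plain int so that three states can be stored. A sketch of the encoding, with the mapping taken from the Token.sent_start setter further down in this diff (the helper name is illustrative only):

# Hedged sketch of the new three-valued encoding of TokenC.sent_start
# (values taken from the Token.sent_start setter in this diff).
def encode_sent_start(value):
    if value is None:
        return 0    # unknown: the parser may decide
    elif value is True:
        return 1    # preset: this token starts a sentence
    elif value is False:
        return -1   # preset: this token does not start a sentence
    raise ValueError("must be one of None, True, False")
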
@@ -307,6 +307,8 @@ cdef cppclass StateC:
         this._stack[this._s_i] = this.B(0)
         this._s_i += 1
         this._b_i += 1
+        if this.B_(0).sent_start == 1:
+            this.set_break(this.B(0))
         if this._b_i > this._break:
             this._break = -1
 
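When the buffer is advanced, the state now checks the next buffer token and records a sentence break if that token carries a preset sent_start of 1. A plain-Python mimic of just the added check (list-based stand-in for the C state; names are illustrative, not spaCy API):

# Hedged mimic of the extra check in the state's push step: after moving the
# buffer forward, a preset boundary on the new first buffer token forces a break.
def push(stack, buffer_idx, sent_starts, breaks):
    stack.append(buffer_idx)
    buffer_idx += 1
    if buffer_idx < len(sent_starts) and sent_starts[buffer_idx] == 1:
        breaks.append(buffer_idx)
    return buffer_idx
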
@@ -383,7 +385,7 @@ cdef cppclass StateC:
 
     void set_break(int i) nogil:
         if 0 <= i < this.length:
-            this._sent[i].sent_start = True
+            this._sent[i].sent_start = 1
             this._break = this._b_i
 
     void clone(const StateC* src) nogil:
@@ -118,7 +118,7 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
 cdef class Shift:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start
+        return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and st.B_(0).sent_start != 1
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
@@ -178,7 +178,7 @@ cdef class Reduce:
 cdef class LeftArc:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return not st.B_(0).sent_start
+        return st.B_(0).sent_start != 1
 
     @staticmethod
     cdef int transition(StateC* st, attr_t label) nogil:
@@ -212,7 +212,7 @@ cdef class LeftArc:
 cdef class RightArc:
     @staticmethod
     cdef bint is_valid(const StateC* st, attr_t label) nogil:
-        return not st.B_(0).sent_start
+        return st.B_(0).sent_start != 1
 
     @staticmethod
    cdef int transition(StateC* st, attr_t label) nogil:
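Shift, LeftArc and RightArc all switch from `not ... sent_start` to `... sent_start != 1`. With the three-valued encoding, a token explicitly marked as not a sentence start stores -1, which is truthy, so the old boolean test would have blocked the transition exactly when it should be allowed. A small self-contained check of the two predicates:

# Hedged illustration of the predicate change: only an explicit 1 (preset
# sentence start) should forbid Shift/LeftArc/RightArc across the boundary.
for sent_start in (0, 1, -1):
    old_valid = not sent_start    # wrongly False for -1 as well as for 1
    new_valid = sent_start != 1   # False only for the preset boundary value 1
    print(sent_start, old_valid, new_valid)
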
@@ -248,6 +248,10 @@ cdef class Break:
             return False
         elif st.stack_depth() < 1:
             return False
+        elif st.B_(0).l_edge < 0:
+            return False
+        elif st._sent[st.B_(0).l_edge].sent_start < 0:
+            return False
         else:
             return True
 
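The two new guards keep the Break transition away from positions where it would contradict a preset annotation: if the first buffer token has no valid left edge, or if the token at that left edge is explicitly marked as not a sentence start (sent_start < 0), Break is invalid. A plain-Python sketch of just these guards (list-based stand-in for the C state, illustrative names):

# Hedged sketch of the added Break guards.
def break_allowed(sent_starts, l_edge):
    if l_edge < 0:
        return False      # no usable left edge for the buffer token
    if sent_starts[l_edge] < 0:
        return False      # left edge token is preset as NOT a sentence start
    return True
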
@@ -244,8 +244,8 @@ cdef class Parser:
         hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 128))
         parser_maxout_pieces = util.env_opt('parser_maxout_pieces', cfg.get('maxout_pieces', 1))
         embed_size = util.env_opt('embed_size', cfg.get('embed_size', 7000))
-        hist_size = util.env_opt('history_feats', cfg.get('hist_size', 4))
-        hist_width = util.env_opt('history_width', cfg.get('hist_width', 16))
+        hist_size = util.env_opt('history_feats', cfg.get('hist_size', 0))
+        hist_width = util.env_opt('history_width', cfg.get('hist_width', 0))
         if hist_size >= 1 and depth == 0:
             raise ValueError("Inconsistent hyper-params: "
                              "history_feats >= 1 but parser_hidden_depth==0")
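The history features are now disabled by default: hist_size and hist_width fall back to 0 instead of 4 and 16. They can still be supplied through the parser's cfg dict; the keys below are taken verbatim from the hunk above, and re-enabling them still requires a non-zero parser hidden depth, as the existing ValueError enforces.

# Hedged sketch: cfg overrides that would restore the previous history-feature
# values (keys taken verbatim from the hunk above); they only make sense when
# the parser's hidden depth is non-zero, per the ValueError check.
cfg = {'hist_size': 4, 'hist_width': 16}
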
spacy/tests/parser/test_preset_sbd.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+'''Test that the parser respects preset sentence boundaries.'''
+from __future__ import unicode_literals
+import pytest
+from thinc.neural.optimizers import Adam
+from thinc.neural.ops import NumpyOps
+
+from ...attrs import NORM
+from ...gold import GoldParse
+from ...vocab import Vocab
+from ...tokens import Doc
+from ...pipeline import NeuralDependencyParser
+
+@pytest.fixture
+def vocab():
+    return Vocab(lex_attr_getters={NORM: lambda s: s})
+
+@pytest.fixture
+def parser(vocab):
+    parser = NeuralDependencyParser(vocab)
+    parser.cfg['token_vector_width'] = 4
+    parser.cfg['hidden_width'] = 32
+    #parser.add_label('right')
+    parser.add_label('left')
+    parser.begin_training([], **parser.cfg)
+    sgd = Adam(NumpyOps(), 0.001)
+
+    for i in range(10):
+        losses = {}
+        doc = Doc(vocab, words=['a', 'b', 'c', 'd'])
+        gold = GoldParse(doc, heads=[1, 1, 3, 3],
+                         deps=['left', 'ROOT', 'left', 'ROOT'])
+        parser.update([doc], [gold], sgd=sgd, losses=losses)
+    return parser
+
+def test_no_sentences(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 2
+
+
+def test_sents_1(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[2].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) >= 2
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = False
+    doc[2].sent_start = True
+    doc[3].sent_start = False
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 2
+
+
+def test_sents_1_2(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = True
+    doc[2].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 3
+
+
+def test_sents_1_3(parser):
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = True
+    doc[3].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 4
+    doc = Doc(parser.vocab, words=['a', 'b', 'c', 'd'])
+    doc[1].sent_start = True
+    doc[2].sent_start = False
+    doc[3].sent_start = True
+    doc = parser(doc)
+    assert len(list(doc.sents)) == 3
@@ -485,7 +485,7 @@ cdef class Doc:
            cdef int i
            start = 0
            for i in range(1, self.length):
-                if self.c[i].sent_start:
+                if self.c[i].sent_start == 1:
                    yield Span(self, start, i)
                    start = i
            if start != self.length:
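Sentence iteration now only breaks on an explicit 1, so tokens carrying 0 (unset) or -1 (forced non-start) no longer open a new span. A plain-Python mimic of the loop above (illustrative helper, not the Doc API):

# Hedged mimic of the sents loop on a list of sent_start values.
def sentence_spans(sent_starts):
    spans, start = [], 0
    for i in range(1, len(sent_starts)):
        if sent_starts[i] == 1:
            spans.append((start, i))
            start = i
    if start != len(sent_starts):
        spans.append((start, len(sent_starts)))
    return spans

print(sentence_spans([1, 0, 1, -1]))  # [(0, 2), (2, 4)]
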
@@ -111,6 +111,30 @@ cdef class Span:
         for i in range(self.start, self.end):
             yield self.doc[i]
 
+    def as_doc(self):
+        '''Create a Doc object view of the Span's data.
+
+        This is mostly useful for C-typed interfaces.
+        '''
+        cdef Doc doc = Doc(self.doc.vocab)
+        doc.length = self.end-self.start
+        doc.c = &self.doc.c[self.start]
+        doc.mem = self.doc.mem
+        doc.is_parsed = self.doc.is_parsed
+        doc.is_tagged = self.doc.is_tagged
+        doc.noun_chunks_iterator = self.doc.noun_chunks_iterator
+        doc.user_hooks = self.doc.user_hooks
+        doc.user_span_hooks = self.doc.user_span_hooks
+        doc.user_token_hooks = self.doc.user_token_hooks
+        doc.vector = self.vector
+        doc.vector_norm = self.vector_norm
+        for key, value in self.doc.cats.items():
+            if hasattr(key, '__len__') and len(key) == 3:
+                cat_start, cat_end, cat_label = key
+                if cat_start == self.start_char and cat_end == self.end_char:
+                    doc.cats[cat_label] = value
+        return doc
+
     def merge(self, *args, **attributes):
         """Retokenize the document, such that the span is merged into a single
         token.
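The new as_doc() method exposes a Span as a Doc that shares the parent's underlying token data, which is convenient for C-typed interfaces such as the parser state. A hedged usage sketch (assumes `doc` is an existing Doc):

# Hedged usage sketch for the new Span.as_doc().
span = doc[2:5]
sub_doc = span.as_doc()           # Doc view over the span's tokens
assert len(sub_doc) == len(span)  # same tokens, shared underlying data
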
@@ -281,13 +281,21 @@ cdef class Token:
         def __get__(self):
             return self.c.sent_start
 
-        def __set__(self, bint value):
+        def __set__(self, value):
             if self.doc.is_parsed:
                 raise ValueError(
                     'Refusing to write to token.sent_start if its document is parsed, '
                     'because this may cause inconsistent state. '
                     'See https://github.com/spacy-io/spaCy/issues/235 for workarounds.')
-            self.c.sent_start = value
+            if value is None:
+                self.c.sent_start = 0
+            elif value is True:
+                self.c.sent_start = 1
+            elif value is False:
+                self.c.sent_start = -1
+            else:
+                raise ValueError("Invalid value for token.sent_start -- must be one of "
+                                 "None, True, False")
 
     property lefts:
         def __get__(self):
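The setter now accepts the three Python values and maps them onto the int encoding instead of coercing to a bint. A hedged usage sketch (on a Doc that has not been parsed yet, since writing to a parsed Doc still raises ValueError):

# Hedged usage sketch of the new setter semantics.
doc[2].sent_start = True    # stored as 1: token 2 must open a sentence
doc[3].sent_start = False   # stored as -1: token 3 must not open a sentence
doc[1].sent_start = None    # stored as 0: no preset, the parser decides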