mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 09:00:36 +03:00
Merge branch 'whatif/arrow' of https://github.com/explosion/spaCy into whatif/arrow
This commit is contained in:
commit
8687fc64eb
|
@ -72,7 +72,7 @@ class Corpus:
|
||||||
i += 1
|
i += 1
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def train_dataset(self, nlp, shuffle=True):
|
def train_dataset(self, nlp, shuffle=True, **kwargs):
|
||||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
||||||
examples = self.make_examples(nlp, ref_docs)
|
examples = self.make_examples(nlp, ref_docs)
|
||||||
if shuffle:
|
if shuffle:
|
||||||
|
@ -80,7 +80,7 @@ class Corpus:
|
||||||
random.shuffle(examples)
|
random.shuffle(examples)
|
||||||
yield from examples
|
yield from examples
|
||||||
|
|
||||||
def dev_dataset(self, nlp):
|
def dev_dataset(self, nlp, **kwargs):
|
||||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
|
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
|
||||||
examples = self.make_examples(nlp, ref_docs)
|
examples = self.make_examples(nlp, ref_docs)
|
||||||
yield from examples
|
yield from examples
|
||||||
|
|
|
@ -196,8 +196,6 @@ cdef class Example:
|
||||||
next_ner = x_tags[i+1] if (i+1) < self.x.length else None
|
next_ner = x_tags[i+1] if (i+1) < self.x.length else None
|
||||||
if prev_ner == "O" or next_ner == "O":
|
if prev_ner == "O" or next_ner == "O":
|
||||||
x_tags[i] = "O"
|
x_tags[i] = "O"
|
||||||
#print("Y tags", y_tags)
|
|
||||||
#print("X tags", x_tags)
|
|
||||||
return x_tags
|
return x_tags
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
|
|
|
@ -584,7 +584,6 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for label, freq in list(label_freqs.items()):
|
for label, freq in list(label_freqs.items()):
|
||||||
if freq < min_freq:
|
if freq < min_freq:
|
||||||
label_freqs.pop(label)
|
label_freqs.pop(label)
|
||||||
print("Removing", action, label, freq)
|
|
||||||
# Ensure these actions are present
|
# Ensure these actions are present
|
||||||
actions[BREAK].setdefault('ROOT', 0)
|
actions[BREAK].setdefault('ROOT', 0)
|
||||||
if kwargs.get("learn_tokens") is True:
|
if kwargs.get("learn_tokens") is True:
|
||||||
|
@ -618,9 +617,29 @@ cdef class ArcEager(TransitionSystem):
|
||||||
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
keeps = [i for i, s in enumerate(states) if not s.is_final()]
|
||||||
states = [states[i] for i in keeps]
|
states = [states[i] for i in keeps]
|
||||||
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
|
||||||
|
for gold in golds:
|
||||||
|
self._replace_unseen_labels(gold)
|
||||||
n_steps = sum([len(s.queue) * 4 for s in states])
|
n_steps = sum([len(s.queue) * 4 for s in states])
|
||||||
return states, golds, n_steps
|
return states, golds, n_steps
|
||||||
|
|
||||||
|
def _replace_unseen_labels(self, ArcEagerGold gold):
|
||||||
|
backoff_label = self.strings["dep"]
|
||||||
|
root_label = self.strings["ROOT"]
|
||||||
|
left_labels = self.labels[LEFT]
|
||||||
|
right_labels = self.labels[RIGHT]
|
||||||
|
break_labels = self.labels[BREAK]
|
||||||
|
for i in range(gold.c.length):
|
||||||
|
if not is_head_unknown(&gold.c, i):
|
||||||
|
head = gold.c.heads[i]
|
||||||
|
label = self.strings[gold.c.labels[i]]
|
||||||
|
if head > i and label not in left_labels:
|
||||||
|
gold.c.labels[i] = backoff_label
|
||||||
|
elif head < i and label not in right_labels:
|
||||||
|
gold.c.labels[i] = backoff_label
|
||||||
|
elif head == i and label not in break_labels:
|
||||||
|
gold.c.labels[i] = root_label
|
||||||
|
return gold
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name_or_id) except *:
|
cdef Transition lookup_transition(self, object name_or_id) except *:
|
||||||
if isinstance(name_or_id, int):
|
if isinstance(name_or_id, int):
|
||||||
return self.c[name_or_id]
|
return self.c[name_or_id]
|
||||||
|
|
|
@ -61,7 +61,7 @@ cdef class TransitionSystem:
|
||||||
offset += len(doc)
|
offset += len(doc)
|
||||||
return states
|
return states
|
||||||
|
|
||||||
def get_oracle_sequence(self, Example example):
|
def get_oracle_sequence(self, Example example, _debug=False):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||||
assert self.n_moves > 0
|
assert self.n_moves > 0
|
||||||
|
@ -70,6 +70,8 @@ cdef class TransitionSystem:
|
||||||
|
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
states, golds, n_steps = self.init_gold_batch([example])
|
states, golds, n_steps = self.init_gold_batch([example])
|
||||||
|
if not states:
|
||||||
|
return []
|
||||||
state = states[0]
|
state = states[0]
|
||||||
gold = golds[0]
|
gold = golds[0]
|
||||||
history = []
|
history = []
|
||||||
|
@ -82,6 +84,7 @@ cdef class TransitionSystem:
|
||||||
history.append(i)
|
history.append(i)
|
||||||
s0 = state.S(0)
|
s0 = state.S(0)
|
||||||
b0 = state.B(0)
|
b0 = state.B(0)
|
||||||
|
if _debug:
|
||||||
debug_log.append(" ".join((
|
debug_log.append(" ".join((
|
||||||
self.get_class_name(i),
|
self.get_class_name(i),
|
||||||
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
|
||||||
|
@ -91,6 +94,7 @@ cdef class TransitionSystem:
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
if _debug:
|
||||||
print("Actions")
|
print("Actions")
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
print(self.get_class_name(i))
|
print(self.get_class_name(i))
|
||||||
|
|
|
@ -65,6 +65,9 @@ def test_oracle_four_words(arc_eager, vocab):
|
||||||
words = ["a", "b", "c", "d"]
|
words = ["a", "b", "c", "d"]
|
||||||
heads = [1, 1, 3, 3]
|
heads = [1, 1, 3, 3]
|
||||||
deps = ["left", "ROOT", "left", "ROOT"]
|
deps = ["left", "ROOT", "left", "ROOT"]
|
||||||
|
for dep in deps:
|
||||||
|
arc_eager.add_action(2, dep) # Left
|
||||||
|
arc_eager.add_action(3, dep) # Right
|
||||||
actions = ["L-left", "B-ROOT", "L-left"]
|
actions = ["L-left", "B-ROOT", "L-left"]
|
||||||
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
state, cost_history = get_sequence_costs(arc_eager, words, heads, deps, actions)
|
||||||
assert state.is_final()
|
assert state.is_final()
|
||||||
|
@ -141,7 +144,7 @@ def test_get_oracle_actions():
|
||||||
doc = Doc(Vocab(), words=[t[1] for t in annot_tuples])
|
doc = Doc(Vocab(), words=[t[1] for t in annot_tuples])
|
||||||
config = {
|
config = {
|
||||||
"learn_tokens": False,
|
"learn_tokens": False,
|
||||||
"min_action_freq": 30,
|
"min_action_freq": 0,
|
||||||
"beam_width": 1,
|
"beam_width": 1,
|
||||||
"beam_update_prob": 1.0,
|
"beam_update_prob": 1.0,
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,6 +92,9 @@ def test_get_oracle_moves_negative_O(tsys, vocab):
|
||||||
assert names
|
assert names
|
||||||
|
|
||||||
|
|
||||||
|
# We can't easily represent this on a Doc object. Not sure what the best solution
|
||||||
|
# would be, but I don't think it's an important use case?
|
||||||
|
@pytest.mark.xfail(reason="No longer supported")
|
||||||
def test_oracle_moves_missing_B(en_vocab):
|
def test_oracle_moves_missing_B(en_vocab):
|
||||||
words = ["B", "52", "Bomber"]
|
words = ["B", "52", "Bomber"]
|
||||||
biluo_tags = [None, None, "L-PRODUCT"]
|
biluo_tags = [None, None, "L-PRODUCT"]
|
||||||
|
@ -114,7 +117,9 @@ def test_oracle_moves_missing_B(en_vocab):
|
||||||
moves.add_action(move_types.index("U"), label)
|
moves.add_action(move_types.index("U"), label)
|
||||||
moves.get_oracle_sequence(example)
|
moves.get_oracle_sequence(example)
|
||||||
|
|
||||||
|
# We can't easily represent this on a Doc object. Not sure what the best solution
|
||||||
|
# would be, but I don't think it's an important use case?
|
||||||
|
@pytest.mark.xfail(reason="No longer supported")
|
||||||
def test_oracle_moves_whitespace(en_vocab):
|
def test_oracle_moves_whitespace(en_vocab):
|
||||||
words = ["production", "\n", "of", "Northrop", "\n", "Corp.", "\n", "'s", "radar"]
|
words = ["production", "\n", "of", "Northrop", "\n", "Corp.", "\n", "'s", "radar"]
|
||||||
biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
|
biluo_tags = ["O", "O", "O", "B-ORG", None, "I-ORG", "L-ORG", "O", "O"]
|
||||||
|
|
|
@ -46,6 +46,8 @@ def test_parser_parse_one_word_sentence(en_tokenizer, en_parser, text):
|
||||||
assert doc[0].dep != 0
|
assert doc[0].dep != 0
|
||||||
|
|
||||||
|
|
||||||
|
# We removed the step_through API a while ago. We should bring it back, though.
|
||||||
|
@pytest.mark.xfail(reason="Unsupported")
|
||||||
def test_parser_initial(en_tokenizer, en_parser):
|
def test_parser_initial(en_tokenizer, en_parser):
|
||||||
text = "I ate the pizza with anchovies."
|
text = "I ate the pizza with anchovies."
|
||||||
# heads = [1, 0, 1, -2, -3, -1, -5]
|
# heads = [1, 0, 1, -2, -3, -1, -5]
|
||||||
|
@ -89,7 +91,8 @@ def test_parser_merge_pp(en_tokenizer):
|
||||||
assert doc[2].text == "another phrase"
|
assert doc[2].text == "another phrase"
|
||||||
assert doc[3].text == "occurs"
|
assert doc[3].text == "occurs"
|
||||||
|
|
||||||
|
# We removed the step_through API a while ago. We should bring it back, though.
|
||||||
|
@pytest.mark.xfail(reason="Unsupported")
|
||||||
def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser):
|
def test_parser_arc_eager_finalize_state(en_tokenizer, en_parser):
|
||||||
text = "a b c d e"
|
text = "a b c d e"
|
||||||
|
|
||||||
|
|
|
@ -288,7 +288,7 @@ def test_issue1967(label):
|
||||||
"entities": [label],
|
"entities": [label],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert "JOB-NAME" in ner.moves.get_actions(gold_parses=[example])[1]
|
assert "JOB-NAME" in ner.moves.get_actions(examples=[example])[1]
|
||||||
|
|
||||||
|
|
||||||
def test_issue1971(en_vocab):
|
def test_issue1971(en_vocab):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user