diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 5b7cffb6b..c4355f1a1 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -172,16 +172,21 @@ def train(
         nlp.disable_pipes(*other_pipes)
         for pipe in pipeline:
             if pipe not in nlp.pipe_names:
-                nlp.add_pipe(nlp.create_pipe(pipe))
+                if pipe == "parser":
+                    pipe_cfg = {"learn_tokens": learn_tokens}
+                else:
+                    pipe_cfg = {}
+                nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
     else:
         msg.text("Starting with blank model '{}'".format(lang))
         lang_cls = util.get_lang_class(lang)
         nlp = lang_cls()
         for pipe in pipeline:
-            nlp.add_pipe(nlp.create_pipe(pipe))
-
-    if learn_tokens:
-        nlp.add_pipe(nlp.create_pipe("merge_subtokens"))
+            if pipe == "parser":
+                pipe_cfg = {"learn_tokens": learn_tokens}
+            else:
+                pipe_cfg = {}
+            nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
 
     if vectors:
         msg.text("Loading vector from model '{}'".format(vectors))
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 8b15ebd38..90ccc2fbf 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1038,7 +1038,10 @@ cdef class DependencyParser(Parser):
 
     @property
     def postprocesses(self):
-        return [nonproj.deprojectivize]
+        output = [nonproj.deprojectivize]
+        if self.cfg.get("learn_tokens") is True:
+            output.append(merge_subtokens)
+        return tuple(output)
 
     def add_multitask_objective(self, target):
         if target == "cloze":
diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index c5b4e4469..eb39124ce 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -362,8 +362,9 @@ cdef class ArcEager(TransitionSystem):
                         label_freqs.pop(label)
         # Ensure these actions are present
         actions[BREAK].setdefault('ROOT', 0)
-        actions[RIGHT].setdefault('subtok', 0)
-        actions[LEFT].setdefault('subtok', 0)
+        if kwargs.get("learn_tokens") is True:
+            actions[RIGHT].setdefault('subtok', 0)
+            actions[LEFT].setdefault('subtok', 0)
         # Used for backoff
         actions[RIGHT].setdefault('dep', 0)
         actions[LEFT].setdefault('dep', 0)
@@ -410,11 +411,23 @@ cdef class ArcEager(TransitionSystem):
     def preprocess_gold(self, GoldParse gold):
         if not self.has_gold(gold):
             return None
+        # Figure out whether we're using subtok
+        use_subtok = False
+        for action, labels in self.labels.items():
+            if SUBTOK_LABEL in labels:
+                use_subtok = True
+                break
         for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)):
             # Missing values
             if head is None or dep is None:
                 gold.c.heads[i] = i
                 gold.c.has_dep[i] = False
+            elif dep == SUBTOK_LABEL and not use_subtok:
+                # If we're not doing the joint tokenization and parsing,
+                # regard these subtok labels as missing
+                gold.c.heads[i] = i
+                gold.c.labels[i] = 0
+                gold.c.has_dep[i] = False
             else:
                 if head > i:
                     action = LEFT
diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx
index fa1a41fa4..c4edef137 100644
--- a/spacy/syntax/nn_parser.pyx
+++ b/spacy/syntax/nn_parser.pyx
@@ -573,7 +573,8 @@ cdef class Parser:
             get_gold_tuples = lambda: gold_tuples
         cfg.setdefault('min_action_freq', 30)
         actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
-                                         min_freq=cfg.get('min_action_freq', 30))
+                                         min_freq=cfg.get('min_action_freq', 30),
+                                         learn_tokens=self.cfg.get("learn_tokens", False))
         for action, labels in self.moves.labels.items():
             actions.setdefault(action, {})
             for label, freq in labels.items():
diff --git a/spacy/tests/regression/test_issue3830.py b/spacy/tests/regression/test_issue3830.py
new file mode 100644
index 000000000..54ce10924
--- /dev/null
+++ b/spacy/tests/regression/test_issue3830.py
@@ -0,0 +1,20 @@
+from spacy.pipeline.pipes import DependencyParser
+from spacy.vocab import Vocab
+
+
+def test_issue3830_no_subtok():
+    """Test that the parser doesn't have subtok label if not learn_tokens"""
+    parser = DependencyParser(Vocab())
+    parser.add_label("nsubj")
+    assert "subtok" not in parser.labels
+    parser.begin_training(lambda: [])
+    assert "subtok" not in parser.labels
+
+
+def test_issue3830_with_subtok():
+    """Test that the parser does have subtok label if learn_tokens=True."""
+    parser = DependencyParser(Vocab(), learn_tokens=True)
+    parser.add_label("nsubj")
+    assert "subtok" not in parser.labels
+    parser.begin_training(lambda: [])
+    assert "subtok" in parser.labels
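
For context, a minimal sketch of how the new flag would be used from the public API, assuming the spaCy 2.1-style pipeline API shown in the patch (the "parser" pipe name and the "learn_tokens" config key come from the diff; the blank "en" pipeline and the empty gold-tuples callable are illustrative only):

    import spacy

    # Hypothetical usage sketch, not part of the patch: create the parser with
    # learn_tokens enabled so the "subtok" action/label is kept and subtokens
    # are merged as a post-process (per DependencyParser.postprocesses above).
    nlp = spacy.blank("en")
    parser = nlp.create_pipe("parser", config={"learn_tokens": True})
    nlp.add_pipe(parser)
    parser.begin_training(lambda: [])
    assert "subtok" in parser.labels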