mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Prevent subtok label if not learning tokens The parser introduces the subtok label to mark tokens that should be merged during post-processing. Previously this happened even if we did not have the --learn-tokens flag set. This patch passes the config through to the parser, to prevent the problem. * Make merge_subtokens a parser post-process if learn_subtokens * Fix train script * Add test for 3830: subtok problem * Fix handling of non-subtok in parser training
This commit is contained in:
		
							parent
							
								
									c417c380e3
								
							
						
					
					
						commit
						bb911e5f4e
					
				|  | @ -172,16 +172,21 @@ def train( | |||
|         nlp.disable_pipes(*other_pipes) | ||||
|         for pipe in pipeline: | ||||
|             if pipe not in nlp.pipe_names: | ||||
|                 nlp.add_pipe(nlp.create_pipe(pipe)) | ||||
|                 if pipe == "parser": | ||||
|                     pipe_cfg = {"learn_tokens": learn_tokens} | ||||
|                 else: | ||||
|                     pipe_cfg = {} | ||||
|                 nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg)) | ||||
|     else: | ||||
|         msg.text("Starting with blank model '{}'".format(lang)) | ||||
|         lang_cls = util.get_lang_class(lang) | ||||
|         nlp = lang_cls() | ||||
|         for pipe in pipeline: | ||||
|             nlp.add_pipe(nlp.create_pipe(pipe)) | ||||
| 
 | ||||
|     if learn_tokens: | ||||
|         nlp.add_pipe(nlp.create_pipe("merge_subtokens")) | ||||
|             if pipe == "parser": | ||||
|                 pipe_cfg = {"learn_tokens": learn_tokens} | ||||
|             else: | ||||
|                 pipe_cfg = {} | ||||
|             nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg)) | ||||
| 
 | ||||
|     if vectors: | ||||
|         msg.text("Loading vector from model '{}'".format(vectors)) | ||||
|  |  | |||
|  | @ -1038,7 +1038,10 @@ cdef class DependencyParser(Parser): | |||
| 
 | ||||
|     @property | ||||
|     def postprocesses(self): | ||||
|         return [nonproj.deprojectivize] | ||||
|         output = [nonproj.deprojectivize] | ||||
|         if self.cfg.get("learn_tokens") is True: | ||||
|             output.append(merge_subtokens) | ||||
|         return tuple(output) | ||||
| 
 | ||||
|     def add_multitask_objective(self, target): | ||||
|         if target == "cloze": | ||||
|  |  | |||
|  | @ -362,8 +362,9 @@ cdef class ArcEager(TransitionSystem): | |||
|                         label_freqs.pop(label) | ||||
|         # Ensure these actions are present | ||||
|         actions[BREAK].setdefault('ROOT', 0) | ||||
|         actions[RIGHT].setdefault('subtok', 0) | ||||
|         actions[LEFT].setdefault('subtok', 0) | ||||
|         if kwargs.get("learn_tokens") is True: | ||||
|             actions[RIGHT].setdefault('subtok', 0) | ||||
|             actions[LEFT].setdefault('subtok', 0) | ||||
|         # Used for backoff | ||||
|         actions[RIGHT].setdefault('dep', 0) | ||||
|         actions[LEFT].setdefault('dep', 0) | ||||
|  | @ -410,11 +411,23 @@ cdef class ArcEager(TransitionSystem): | |||
|     def preprocess_gold(self, GoldParse gold): | ||||
|         if not self.has_gold(gold): | ||||
|             return None | ||||
|         # Figure out whether we're using subtok | ||||
|         use_subtok = False | ||||
|         for action, labels in self.labels.items(): | ||||
|             if SUBTOK_LABEL in labels: | ||||
|                 use_subtok = True | ||||
|                 break | ||||
|         for i, (head, dep) in enumerate(zip(gold.heads, gold.labels)): | ||||
|             # Missing values | ||||
|             if head is None or dep is None: | ||||
|                 gold.c.heads[i] = i | ||||
|                 gold.c.has_dep[i] = False | ||||
|             elif dep == SUBTOK_LABEL and not use_subtok: | ||||
|                 # If we're not doing the joint tokenization and parsing, | ||||
|                 # regard these subtok labels as missing | ||||
|                 gold.c.heads[i] = i | ||||
|                 gold.c.labels[i] = 0 | ||||
|                 gold.c.has_dep[i] = False | ||||
|             else: | ||||
|                 if head > i: | ||||
|                     action = LEFT | ||||
|  |  | |||
|  | @ -573,7 +573,8 @@ cdef class Parser: | |||
|             get_gold_tuples = lambda: gold_tuples | ||||
|         cfg.setdefault('min_action_freq', 30) | ||||
|         actions = self.moves.get_actions(gold_parses=get_gold_tuples(), | ||||
|                                          min_freq=cfg.get('min_action_freq', 30)) | ||||
|                                          min_freq=cfg.get('min_action_freq', 30), | ||||
|                                          learn_tokens=self.cfg.get("learn_tokens", False)) | ||||
|         for action, labels in self.moves.labels.items(): | ||||
|             actions.setdefault(action, {}) | ||||
|             for label, freq in labels.items(): | ||||
|  |  | |||
							
								
								
									
										20
									
								
								spacy/tests/regression/test_issue3830.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								spacy/tests/regression/test_issue3830.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,20 @@ | |||
| from spacy.pipeline.pipes import DependencyParser | ||||
| from spacy.vocab import Vocab | ||||
| 
 | ||||
| 
 | ||||
def test_issue3830_no_subtok():
    """Regression test for #3830: without learn_tokens, the parser must not
    introduce the internal "subtok" label, neither before nor after training
    begins."""
    vocab = Vocab()
    parser = DependencyParser(vocab)
    parser.add_label("nsubj")
    # The label set should stay free of "subtok" both before and after
    # begin_training, since learn_tokens defaults to False.
    assert "subtok" not in parser.labels
    parser.begin_training(lambda: [])
    assert "subtok" not in parser.labels
| 
 | ||||
| 
 | ||||
def test_issue3830_with_subtok():
    """Regression test for #3830: with learn_tokens=True, the parser should
    add the internal "subtok" label when training begins (but not before)."""
    vocab = Vocab()
    parser = DependencyParser(vocab, learn_tokens=True)
    parser.add_label("nsubj")
    # "subtok" only gets introduced by begin_training, not by add_label.
    assert "subtok" not in parser.labels
    parser.begin_training(lambda: [])
    assert "subtok" in parser.labels
		Loading…
	
		Reference in New Issue
	
	Block a user