mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Allow more components to use labels
This commit is contained in:
		
							parent
							
								
									99bff78617
								
							
						
					
					
						commit
						1fd002180e
					
				|  | @ -160,16 +160,12 @@ class TextCategorizer(Pipe): | ||||||
|         self.cfg["labels"] = tuple(value) |         self.cfg["labels"] = tuple(value) | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def label_data(self) -> Dict: |     def label_data(self) -> List[str]: | ||||||
|         """RETURNS (Dict): Information about the component's labels. |         """RETURNS (List[str]): Information about the component's labels. | ||||||
| 
 | 
 | ||||||
|         DOCS: https://nightly.spacy.io/api/textcategorizer#labels |         DOCS: https://nightly.spacy.io/api/textcategorizer#labels | ||||||
|         """ |         """ | ||||||
|         return { |         return self.labels | ||||||
|             "labels": self.labels, |  | ||||||
|             "positive": self.cfg["positive_label"], |  | ||||||
|             "threshold": self.cfg["threshold"] |  | ||||||
|         } |  | ||||||
| 
 | 
 | ||||||
|     def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]: |     def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]: | ||||||
|         """Apply the pipe to a stream of documents. This usually happens under |         """Apply the pipe to a stream of documents. This usually happens under | ||||||
|  | @ -354,6 +350,7 @@ class TextCategorizer(Pipe): | ||||||
|         get_examples: Callable[[], Iterable[Example]], |         get_examples: Callable[[], Iterable[Example]], | ||||||
|         *, |         *, | ||||||
|         nlp: Optional[Language] = None, |         nlp: Optional[Language] = None, | ||||||
|  |         labels: Optional[Dict] = None | ||||||
|     ): |     ): | ||||||
|         """Initialize the pipe for training, using a representative set |         """Initialize the pipe for training, using a representative set | ||||||
|         of data examples. |         of data examples. | ||||||
|  | @ -365,12 +362,14 @@ class TextCategorizer(Pipe): | ||||||
|         DOCS: https://nightly.spacy.io/api/textcategorizer#initialize |         DOCS: https://nightly.spacy.io/api/textcategorizer#initialize | ||||||
|         """ |         """ | ||||||
|         self._ensure_examples(get_examples) |         self._ensure_examples(get_examples) | ||||||
|         subbatch = []  # Select a subbatch of examples to initialize the model |         if labels is None: | ||||||
|         for example in islice(get_examples(), 10): |             for example in get_examples(): | ||||||
|             if len(subbatch) < 2: |                 for cat in example.y.cats: | ||||||
|                 subbatch.append(example) |                     self.add_label(cat) | ||||||
|             for cat in example.y.cats: |         else: | ||||||
|                 self.add_label(cat) |             for label in labels: | ||||||
|  |                 self.add_label(label) | ||||||
|  |         subbatch = list(islice(get_examples(), 10)) | ||||||
|         doc_sample = [eg.reference for eg in subbatch] |         doc_sample = [eg.reference for eg in subbatch] | ||||||
|         label_sample, _ = self._examples_to_truth(subbatch) |         label_sample, _ = self._examples_to_truth(subbatch) | ||||||
|         self._require_labels() |         self._require_labels() | ||||||
|  |  | ||||||
|  | @ -409,17 +409,20 @@ cdef class Parser(Pipe): | ||||||
|     def set_output(self, nO): |     def set_output(self, nO): | ||||||
|         self.model.attrs["resize_output"](self.model, nO) |         self.model.attrs["resize_output"](self.model, nO) | ||||||
| 
 | 
 | ||||||
|     def initialize(self, get_examples, nlp=None): |     def initialize(self, get_examples, *, nlp=None, labels=None): | ||||||
|         self._ensure_examples(get_examples) |         self._ensure_examples(get_examples) | ||||||
|         lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {}) |         lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {}) | ||||||
|         if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS: |         if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS: | ||||||
|             langs = ", ".join(util.LEXEME_NORM_LANGS) |             langs = ", ".join(util.LEXEME_NORM_LANGS) | ||||||
|             util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs)) |             util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs)) | ||||||
|         actions = self.moves.get_actions( |         if labels is not None: | ||||||
|             examples=get_examples(), |             actions = dict(labels) | ||||||
|             min_freq=self.cfg['min_action_freq'], |         else: | ||||||
|             learn_tokens=self.cfg["learn_tokens"] |             actions = self.moves.get_actions( | ||||||
|         ) |                 examples=get_examples(), | ||||||
|  |                 min_freq=self.cfg['min_action_freq'], | ||||||
|  |                 learn_tokens=self.cfg["learn_tokens"] | ||||||
|  |             ) | ||||||
|         for action, labels in self.moves.labels.items(): |         for action, labels in self.moves.labels.items(): | ||||||
|             actions.setdefault(action, {}) |             actions.setdefault(action, {}) | ||||||
|             for label, freq in labels.items(): |             for label, freq in labels.items(): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user