Merge branch 'feature/prepare' of https://github.com/explosion/spaCy into feature/prepare

This commit is contained in:
Ines Montani 2020-09-29 16:59:35 +02:00
commit 1c60f0b5e9
8 changed files with 39 additions and 54 deletions

View File

@@ -34,7 +34,7 @@ def init_labels_cli(
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config(config_path, overrides=overrides) config = util.load_config(config_path, overrides=overrides)
with show_validation_error(hint_fill=False): with show_validation_error(hint_fill=False):
nlp = init_nlp(config, use_gpu=use_gpu, silent=False) nlp = init_nlp(config, use_gpu=use_gpu)
for name, component in nlp.pipeline: for name, component in nlp.pipeline:
if getattr(component, "label_data", None) is not None: if getattr(component, "label_data", None) is not None:
srsly.write_json(output_path / f"{name}.json", component.label_data) srsly.write_json(output_path / f"{name}.json", component.label_data)

View File

@@ -56,7 +56,7 @@ def train_cli(
def init_pipeline( def init_pipeline(
config: Config, output_path: Optional[Path], *, use_gpu: int = -1 config: Config, output_path: Optional[Path], *, use_gpu: int = -1
) -> Language: ) -> Language:
init_kwargs = {"use_gpu": use_gpu, "silent": False} init_kwargs = {"use_gpu": use_gpu}
if output_path is not None: if output_path is not None:
init_path = output_path / "model-initial" init_path = output_path / "model-initial"
if not init_path.exists(): if not init_path.exists():
@@ -74,12 +74,6 @@ def init_pipeline(
else: else:
msg.good(f"Loaded initialized pipeline from {init_path}") msg.good(f"Loaded initialized pipeline from {init_path}")
return nlp return nlp
msg.warn(
"Not saving initialized model: no output directory specified. "
"To speed up training, spaCy can save the initialized nlp object with "
"the vocabulary, vectors and label scheme. To take advantage of this, "
"provide an output directory."
)
return init_nlp(config, **init_kwargs) return init_nlp(config, **init_kwargs)

View File

@@ -1181,24 +1181,9 @@ class Language:
) )
doc = Doc(self.vocab, words=["x", "y", "z"]) doc = Doc(self.vocab, words=["x", "y", "z"])
get_examples = lambda: [Example.from_dict(doc, {})] get_examples = lambda: [Example.from_dict(doc, {})]
# Populate vocab
if not hasattr(get_examples, "__call__"): if not hasattr(get_examples, "__call__"):
err = Errors.E930.format(name="Language", obj=type(get_examples)) err = Errors.E930.format(name="Language", obj=type(get_examples))
raise ValueError(err) raise ValueError(err)
valid_examples = False
for example in get_examples():
if not isinstance(example, Example):
err = Errors.E978.format(
name="Language.initialize", types=type(example)
)
raise ValueError(err)
else:
valid_examples = True
for word in [t.text for t in example.reference]:
_ = self.vocab[word] # noqa: F841
if not valid_examples:
err = Errors.E930.format(name="Language", obj="empty list")
raise ValueError(err)
# Make sure the config is interpolated so we can resolve subsections # Make sure the config is interpolated so we can resolve subsections
config = self.config.interpolate() config = self.config.interpolate()
# These are the settings provided in the [initialize] block in the config # These are the settings provided in the [initialize] block in the config

View File

@@ -35,10 +35,7 @@ cdef class Pipe:
@property @property
def labels(self) -> Optional[Tuple[str]]: def labels(self) -> Optional[Tuple[str]]:
if "labels" in self.cfg: return []
return tuple(self.cfg["labels"])
else:
return None
@property @property
def label_data(self): def label_data(self):

View File

@@ -266,7 +266,7 @@ class Tagger(Pipe):
raise ValueError("nan value when computing loss") raise ValueError("nan value when computing loss")
return float(loss), d_scores return float(loss), d_scores
def initialize(self, get_examples, *, nlp=None): def initialize(self, get_examples, *, nlp=None, labels=None):
"""Initialize the pipe for training, using a representative set """Initialize the pipe for training, using a representative set
of data examples. of data examples.
@@ -277,15 +277,19 @@ class Tagger(Pipe):
DOCS: https://nightly.spacy.io/api/tagger#initialize DOCS: https://nightly.spacy.io/api/tagger#initialize
""" """
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
if labels is not None:
for tag in labels:
self.add_label(tag)
else:
tags = set()
for example in get_examples():
for token in example.y:
if token.tag_:
tags.add(token.tag_)
for tag in sorted(tags):
self.add_label(tag)
doc_sample = [] doc_sample = []
label_sample = [] label_sample = []
tags = set()
for example in get_examples():
for token in example.y:
if token.tag_:
tags.add(token.tag_)
for tag in sorted(tags):
self.add_label(tag)
for example in islice(get_examples(), 10): for example in islice(get_examples(), 10):
doc_sample.append(example.x) doc_sample.append(example.x)
gold_tags = example.get_aligned("TAG", as_string=True) gold_tags = example.get_aligned("TAG", as_string=True)

View File

@@ -160,16 +160,12 @@ class TextCategorizer(Pipe):
self.cfg["labels"] = tuple(value) self.cfg["labels"] = tuple(value)
@property @property
def label_data(self) -> Dict: def label_data(self) -> List[str]:
"""RETURNS (Dict): Information about the component's labels. """RETURNS (List[str]): Information about the component's labels.
DOCS: https://nightly.spacy.io/api/textcategorizer#labels DOCS: https://nightly.spacy.io/api/textcategorizer#labels
""" """
return { return self.labels
"labels": self.labels,
"positive": self.cfg["positive_label"],
"threshold": self.cfg["threshold"]
}
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]: def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents. This usually happens under """Apply the pipe to a stream of documents. This usually happens under
@@ -354,6 +350,7 @@ class TextCategorizer(Pipe):
get_examples: Callable[[], Iterable[Example]], get_examples: Callable[[], Iterable[Example]],
*, *,
nlp: Optional[Language] = None, nlp: Optional[Language] = None,
labels: Optional[Dict] = None
): ):
"""Initialize the pipe for training, using a representative set """Initialize the pipe for training, using a representative set
of data examples. of data examples.
@@ -365,12 +362,14 @@ class TextCategorizer(Pipe):
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
""" """
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
subbatch = [] # Select a subbatch of examples to initialize the model if labels is None:
for example in islice(get_examples(), 10): for example in get_examples():
if len(subbatch) < 2: for cat in example.y.cats:
subbatch.append(example) self.add_label(cat)
for cat in example.y.cats: else:
self.add_label(cat) for label in labels:
self.add_label(label)
subbatch = list(islice(get_examples(), 10))
doc_sample = [eg.reference for eg in subbatch] doc_sample = [eg.reference for eg in subbatch]
label_sample, _ = self._examples_to_truth(subbatch) label_sample, _ = self._examples_to_truth(subbatch)
self._require_labels() self._require_labels()

View File

@@ -409,17 +409,20 @@ cdef class Parser(Pipe):
def set_output(self, nO): def set_output(self, nO):
self.model.attrs["resize_output"](self.model, nO) self.model.attrs["resize_output"](self.model, nO)
def initialize(self, get_examples, nlp=None): def initialize(self, get_examples, *, nlp=None, labels=None):
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {}) lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS: if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
langs = ", ".join(util.LEXEME_NORM_LANGS) langs = ", ".join(util.LEXEME_NORM_LANGS)
util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs)) util.logger.debug(Warnings.W033.format(model="parser or NER", langs=langs))
actions = self.moves.get_actions( if labels is not None:
examples=get_examples(), actions = dict(labels)
min_freq=self.cfg['min_action_freq'], else:
learn_tokens=self.cfg["learn_tokens"] actions = self.moves.get_actions(
) examples=get_examples(),
min_freq=self.cfg['min_action_freq'],
learn_tokens=self.cfg["learn_tokens"]
)
for action, labels in self.moves.labels.items(): for action, labels in self.moves.labels.items():
actions.setdefault(action, {}) actions.setdefault(action, {})
for label, freq in labels.items(): for label, freq in labels.items():

View File

@@ -97,6 +97,9 @@ class registry(thinc.registry):
models = catalogue.create("spacy", "models", entry_points=True) models = catalogue.create("spacy", "models", entry_points=True)
cli = catalogue.create("spacy", "cli", entry_points=True) cli = catalogue.create("spacy", "cli", entry_points=True)
# We want json loading in the registry, so manually register srsly.read_json.
registry.readers("srsly.read_json.v0", srsly.read_json)
class SimpleFrozenDict(dict): class SimpleFrozenDict(dict):
"""Simplified implementation of a frozen dict, mainly used as default """Simplified implementation of a frozen dict, mainly used as default