spaCy/spacy/syntax/nn_parser.pyx

# cython: infer_types=True, cdivision=True, boundscheck=False
cimport cython.parallel
cimport numpy as np
from cpython.ref cimport PyObject, Py_XDECREF
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
from libc.math cimport exp
from libcpp.vector cimport vector
from libc.string cimport memset, memcpy
from libc.stdlib cimport calloc, free
from cymem.cymem cimport Pool
from thinc.extra.search cimport Beam
from thinc.backends.linalg cimport Vec, VecVec
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
from thinc.api import get_array_module, zero_init, set_dropout_rate
from itertools import islice
import srsly
import numpy.random
import numpy
import warnings
from ..tokens.doc cimport Doc
from ..gold cimport GoldParse
from ..typedefs cimport weight_t, class_t, hash_t
from ._parser_model cimport alloc_activations, free_activations
from ._parser_model cimport predict_states, arg_max_if_valid
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ._parser_model cimport get_c_weights, get_c_sizes
from .stateclass cimport StateClass
from ._state cimport StateC
from .transition_system cimport Transition
from . cimport _beam_utils
from ..gold import Example
from ..util import link_vectors_to_models, create_default_optimizer, registry
from ..compat import copy_array
from ..errors import Errors, Warnings
from .. import util
from ._parser_model import ParserModel
from . import _beam_utils
from . import nonproj
cdef class Parser:
    """
    Base class of the DependencyParser and EntityRecognizer.
    """
    name = 'base_parser'
    def __init__(self, Vocab vocab, model, **cfg):
        """Create a Parser.

        vocab (Vocab): The vocabulary object. Must be shared with documents
            to be processed. The value is set to the `.vocab` attribute.
        **cfg: Configuration parameters. Set to the `.cfg` attribute.
            If it doesn't include a value for 'moves', a new instance is
            created with `self.TransitionSystem()`. This defines how the
            parse-state is created, updated and evaluated.
        """
        self.vocab = vocab
        moves = cfg.get("moves", None)
        if moves is None:
            # defined by EntityRecognizer as a BiluoPushDown
            moves = self.TransitionSystem(self.vocab.strings)
        self.moves = moves
        cfg.setdefault('min_action_freq', 30)
        cfg.setdefault('learn_tokens', False)
        cfg.setdefault('beam_width', 1)
        cfg.setdefault('beam_update_prob', 1.0)  # or 0.5 (both defaults were previously used)
        self.model = model
        self.set_output(self.moves.n_moves)
        self.cfg = cfg
        self._multitasks = []
        self._rehearsal_model = None

    @classmethod
    def from_nlp(cls, nlp, model, **cfg):
        return cls(nlp.vocab, model, **cfg)

    def __reduce__(self):
        return (Parser, (self.vocab, self.model), self.moves)

    def __getstate__(self):
        return self.moves

    def __setstate__(self, moves):
        self.moves = moves
--> - Lemma dictionary used can be found [here](http://infolingu.univ-mlv.fr/DonneesLinguistiques/Dictionnaires/telechargement.html), I used the XML version. - Add several files containing exhaustive list of words for each part of speech - Add some lemma rules - Add POS that are not checked in the standard Lemmatizer, i.e PRON, DET, ADV and AUX - Modify the Lemmatizer class to check in lookup table as a last resort if POS not mentionned - Modify the lemmatize function to check in lookup table as a last resort - Init files are updated so the model can support all the functionalities mentioned above - Add words to tokenizer_exceptions_list.py in respect to regex used in tokenizer_exceptions.py ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [X] I have submitted the spaCy Contributor Agreement. - [X] I ran the tests, and all new and existing tests passed. - [X] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Set version to 2.0.13 * Fix formatting and consistency * Update docs for new version [ci skip] * Increment version [ci skip] * Add info on wheels [ci skip] * Adding "This is a sentence" example to Sinhala (#2846) * Add wheels badge * Update badge [ci skip] * Update README.rst [ci skip] * Update murmurhash pin * Increment version to 2.0.14.dev0 * Update GPU docs for v2.0.14 * Add wheel to setup_requires * Import prefer_gpu and require_gpu functions from Thinc * Add tests for prefer_gpu() and require_gpu() * Update requirements and setup.py * Workaround bug in thinc require_gpu * Set version to v2.0.14 * Update push-tag script * Unhack prefer_gpu * Require thinc 6.10.6 * Update prefer_gpu and require_gpu docs [ci skip] * Fix specifiers for GPU * Set version to 2.0.14.dev1 * Set version to 2.0.14 * Update Thinc version pin * Increment version * Fix msgpack-numpy version pin * Increment version * Update version to 2.0.16 * Update version [ci skip] * Redundant ')' in the Stop words' example (#2856) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [ ] I have submitted the spaCy Contributor Agreement. - [ ] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Documentation improvement regarding joblib and SO (#2867) Some documentation improvements ## Description 1. Fixed the dead URL to joblib 2. Fixed Stack Overflow brand name (with space) ### Types of change Documentation ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * raise error when setting overlapping entities as doc.ents (#2880) * Fix out-of-bounds access in NER training The helper method state.B(1) gets the index of the first token of the buffer, or -1 if no such token exists. Normally this is safe because we pass this to functions like state.safe_get(), which returns an empty token. Here we used it directly as an array index, which is not okay! This error may have been the cause of out-of-bounds access errors during training. Similar errors may still be around, so much be hunted down. Hunting this one down took a long time...I printed out values across training runs and diffed, looking for points of divergence between runs, when no randomness should be allowed. * Change PyThaiNLP Url (#2876) * Fix missing comma * Add example showing a fix-up rule for space entities * Set version to 2.0.17.dev0 * Update regex version * Revert "Update regex version" This reverts commit 62358dd867d15bc6a475942dff34effba69dd70a. 
* Try setting older regex version, to align with conda * Set version to 2.0.17 * Add spacy-js to universe [ci-skip] * Add spacy-raspberry to universe (closes #2889) * Add script to validate universe json [ci skip] * Removed space in docs + added contributor indo (#2909) * - removed unneeded space in documentation * - added contributor info * Allow input text of length up to max_length, inclusive (#2922) * Include universe spec for spacy-wordnet component (#2919) * feat: include universe spec for spacy-wordnet component * chore: include spaCy contributor agreement * Minor formatting changes [ci skip] * Fix image [ci skip] Twitter URL doesn't work on live site * Check if the word is in one of the regular lists specific to each POS (#2886) * 💫 Create random IDs for SVGs to prevent ID clashes (#2927) Resolves #2924. ## Description Fixes problem where multiple visualizations in Jupyter notebooks would have clashing arc IDs, resulting in weirdly positioned arc labels. Generating a random ID prefix so even identical parses won't receive the same IDs for consistency (even if effect of ID clash isn't noticable here.) ### Types of change bug fix ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Fix typo [ci skip] * fixes symbolic link on py3 and windows (#2949) * fixes symbolic link on py3 and windows during setup of spacy using command python -m spacy link en_core_web_sm en closes #2948 * Update spacy/compat.py Co-Authored-By: cicorias <cicorias@users.noreply.github.com> * Fix formatting * Update universe [ci skip] * Catalan Language Support (#2940) * Catalan language Support * Ddding Catalan to documentation * Sort languages alphabetically [ci skip] * Update tests for pytest 4.x (#2965) <!--- Provide a general summary of your changes in the title. --> ## Description - [x] Replace marks in params for pytest 4.0 compat ([see here](https://docs.pytest.org/en/latest/deprecations.html#marks-in-pytest-mark-parametrize)) - [x] Un-xfail passing tests (some fixes in a recent update resolved a bunch of issues, but tests were apparently never updated here) ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Fix regex pin to harmonize with conda (#2964) * Update README.rst * Fix bug where Vocab.prune_vector did not use 'batch_size' (#2977) Fixes #2976 * Fix typo * Fix typo * Remove duplicate file * Require thinc 7.0.0.dev2 Fixes bug in gpu_ops that would use cupy instead of numpy on CPU * Add missing import * Fix error IDs * Fix tests
2018-11-29 18:30:29 +03:00
    @property
    def move_names(self):
        names = []
        for i in range(self.moves.n_moves):
            name = self.moves.move_name(self.moves.c[i].move, self.moves.c[i].label)
            # Explicitly removing the internal "U-" token used for blocking entities
            if name != "U-":
                names.append(name)
        return names

    @property
    def labels(self):
        class_names = [self.moves.get_class_name(i) for i in range(self.moves.n_moves)]
        return class_names

    @property
    def tok2vec(self):
        '''Return the embedding and convolutional layer of the model.'''
        return self.model.tok2vec

    @property
    def postprocesses(self):
        # Available for subclasses, e.g. to deprojectivize
        return []

    def add_label(self, label):
        resized = False
        for action in self.moves.action_types:
            added = self.moves.add_action(action, label)
            if added:
                resized = True
        if resized:
            self._resize()

    def _resize(self):
        self.model.resize_output(self.moves.n_moves)
        if self._rehearsal_model not in (True, False, None):
            self._rehearsal_model.resize_output(self.moves.n_moves)
The lack of batching in the examples has caused some confusion in the past, especially for beginners who would copy-paste the examples, update them with large training sets and experienced slow and unsatisfying results. ### Types of change enhancements ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Visual C++ link updated (#2842) (closes #2841) [ci skip] * New landing page * Add contribution agreement * Correcting lang/ru/examples.py (#2845) * Correct some grammatical inaccuracies in lang\ru\examples.py; filled Contributor Agreement * Correct some grammatical inaccuracies in lang\ru\examples.py * Move contributor agreement to separate file * Set version to 2.0.13.dev4 * Add Persian(Farsi) language support (#2797) * Also include lowercase norm exceptions * Remove in favour of https://github.com/explosion/spaCy/graphs/contributors * Rule-based French Lemmatizer (#2818) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> Add a rule-based French Lemmatizer following the english one and the excellent PR for [greek language optimizations](https://github.com/explosion/spaCy/pull/2558) to adapt the Lemmatizer class. ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? 
--> - Lemma dictionary used can be found [here](http://infolingu.univ-mlv.fr/DonneesLinguistiques/Dictionnaires/telechargement.html), I used the XML version. - Add several files containing exhaustive list of words for each part of speech - Add some lemma rules - Add POS that are not checked in the standard Lemmatizer, i.e PRON, DET, ADV and AUX - Modify the Lemmatizer class to check in lookup table as a last resort if POS not mentionned - Modify the lemmatize function to check in lookup table as a last resort - Init files are updated so the model can support all the functionalities mentioned above - Add words to tokenizer_exceptions_list.py in respect to regex used in tokenizer_exceptions.py ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [X] I have submitted the spaCy Contributor Agreement. - [X] I ran the tests, and all new and existing tests passed. - [X] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Set version to 2.0.13 * Fix formatting and consistency * Update docs for new version [ci skip] * Increment version [ci skip] * Add info on wheels [ci skip] * Adding "This is a sentence" example to Sinhala (#2846) * Add wheels badge * Update badge [ci skip] * Update README.rst [ci skip] * Update murmurhash pin * Increment version to 2.0.14.dev0 * Update GPU docs for v2.0.14 * Add wheel to setup_requires * Import prefer_gpu and require_gpu functions from Thinc * Add tests for prefer_gpu() and require_gpu() * Update requirements and setup.py * Workaround bug in thinc require_gpu * Set version to v2.0.14 * Update push-tag script * Unhack prefer_gpu * Require thinc 6.10.6 * Update prefer_gpu and require_gpu docs [ci skip] * Fix specifiers for GPU * Set version to 2.0.14.dev1 * Set version to 2.0.14 * Update Thinc version pin * Increment version * Fix msgpack-numpy version pin * Increment version * Update version to 2.0.16 * Update version [ci skip] * Redundant ')' in the Stop words' example (#2856) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [ ] I have submitted the spaCy Contributor Agreement. - [ ] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Documentation improvement regarding joblib and SO (#2867) Some documentation improvements ## Description 1. Fixed the dead URL to joblib 2. Fixed Stack Overflow brand name (with space) ### Types of change Documentation ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * raise error when setting overlapping entities as doc.ents (#2880) * Fix out-of-bounds access in NER training The helper method state.B(1) gets the index of the first token of the buffer, or -1 if no such token exists. Normally this is safe because we pass this to functions like state.safe_get(), which returns an empty token. Here we used it directly as an array index, which is not okay! This error may have been the cause of out-of-bounds access errors during training. Similar errors may still be around, so much be hunted down. Hunting this one down took a long time...I printed out values across training runs and diffed, looking for points of divergence between runs, when no randomness should be allowed. * Change PyThaiNLP Url (#2876) * Fix missing comma * Add example showing a fix-up rule for space entities * Set version to 2.0.17.dev0 * Update regex version * Revert "Update regex version" This reverts commit 62358dd867d15bc6a475942dff34effba69dd70a. 
* Try setting older regex version, to align with conda * Set version to 2.0.17 * Add spacy-js to universe [ci-skip] * Add spacy-raspberry to universe (closes #2889) * Add script to validate universe json [ci skip] * Removed space in docs + added contributor indo (#2909) * - removed unneeded space in documentation * - added contributor info * Allow input text of length up to max_length, inclusive (#2922) * Include universe spec for spacy-wordnet component (#2919) * feat: include universe spec for spacy-wordnet component * chore: include spaCy contributor agreement * Minor formatting changes [ci skip] * Fix image [ci skip] Twitter URL doesn't work on live site * Check if the word is in one of the regular lists specific to each POS (#2886) * 💫 Create random IDs for SVGs to prevent ID clashes (#2927) Resolves #2924. ## Description Fixes problem where multiple visualizations in Jupyter notebooks would have clashing arc IDs, resulting in weirdly positioned arc labels. Generating a random ID prefix so even identical parses won't receive the same IDs for consistency (even if effect of ID clash isn't noticable here.) ### Types of change bug fix ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Fix typo [ci skip] * fixes symbolic link on py3 and windows (#2949) * fixes symbolic link on py3 and windows during setup of spacy using command python -m spacy link en_core_web_sm en closes #2948 * Update spacy/compat.py Co-Authored-By: cicorias <cicorias@users.noreply.github.com> * Fix formatting * Update universe [ci skip] * Catalan Language Support (#2940) * Catalan language Support * Ddding Catalan to documentation * Sort languages alphabetically [ci skip] * Update tests for pytest 4.x (#2965) <!--- Provide a general summary of your changes in the title. --> ## Description - [x] Replace marks in params for pytest 4.0 compat ([see here](https://docs.pytest.org/en/latest/deprecations.html#marks-in-pytest-mark-parametrize)) - [x] Un-xfail passing tests (some fixes in a recent update resolved a bunch of issues, but tests were apparently never updated here) ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Fix regex pin to harmonize with conda (#2964) * Update README.rst * Fix bug where Vocab.prune_vector did not use 'batch_size' (#2977) Fixes #2976 * Fix typo * Fix typo * Remove duplicate file * Require thinc 7.0.0.dev2 Fixes bug in gpu_ops that would use cupy instead of numpy on CPU * Add missing import * Fix error IDs * Fix tests
2018-11-29 18:30:29 +03:00
    def add_multitask_objective(self, target):
        # Defined in subclasses, to avoid circular import
        raise NotImplementedError

    def init_multitask_objectives(self, get_examples, pipeline, **cfg):
        '''Set up models for secondary objectives, to benefit from multi-task
        learning. This method is intended to be overridden by subclasses.

        For instance, the dependency parser can benefit from sharing
        an input representation with a label prediction model. These auxiliary
        models are discarded after training.
        '''
        pass

    def preprocess_gold(self, examples):
        for ex in examples:
            yield ex

    def use_params(self, params):
        # Can't decorate cdef class :(. Workaround.
        with self.model.use_params(params):
            yield
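    # Sketch, assuming a caller that drives the generator by hand (the
    # names `parser` and `params` are hypothetical): because `use_params`
    # is a bare generator rather than a decorated context manager, it
    # cannot be used directly in a `with` statement. One way to use it:
    #
    #     ctx = parser.use_params(params)
    #     next(ctx)        # runs up to the `yield`; params are swapped in
    #     ...              # work with the temporary parameters
    #     next(ctx, None)  # resumes after the `yield`; params are restored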

    def __call__(self, Doc doc, beam_width=None):
        """Apply the parser or entity recognizer, setting the annotations onto
        the `Doc` object.

        doc (Doc): The document to be processed.
        """
        if beam_width is None:
            beam_width = self.cfg['beam_width']
        beam_density = self.cfg.get('beam_density', 0.)
        states = self.predict([doc], beam_width=beam_width,
                              beam_density=beam_density)
        self.set_annotations([doc], states, tensors=None)
        return doc

    def pipe(self, docs, int batch_size=256, int n_threads=-1, beam_width=None,
             as_example=False):
        """Process a stream of documents.

        docs: The sequence of documents (or Examples) to process.
        batch_size (int): Number of documents to accumulate into a working set.
        beam_width (int): Beam width to use. Defaults to cfg['beam_width'].
        as_example (bool): Whether to yield Example objects instead of Docs.
        YIELDS (Doc or Example): The processed items, in order.
        """
        if beam_width is None:
            beam_width = self.cfg['beam_width']
        beam_density = self.cfg.get('beam_density', 0.)
        cdef Doc doc
        for batch in util.minibatch(docs, size=batch_size):
            batch_in_order = list(batch)
            docs = [self._get_doc(ex) for ex in batch_in_order]
            # Parse in length-sorted sub-batches; annotations are set in place.
            by_length = sorted(docs, key=lambda doc: len(doc))
            for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
                subbatch = list(subbatch)
                parse_states = self.predict(subbatch, beam_width=beam_width,
                                            beam_density=beam_density)
                self.set_annotations(subbatch, parse_states, tensors=None)
            # Yield in the original input order, not the length-sorted order.
            if as_example:
                annotated_examples = []
                for ex, doc in zip(batch_in_order, docs):
                    ex.doc = doc
                    annotated_examples.append(ex)
                yield from annotated_examples
            else:
                yield from batch_in_order

    def predict(self, docs, beam_width=1, beam_density=0.0, drop=0.):
        if isinstance(docs, Doc):
            docs = [docs]
        if not any(len(doc) for doc in docs):
            result = self.moves.init_batch(docs)
            self._resize()
            return result
        if beam_width < 2:
            return self.greedy_parse(docs, drop=drop)
        else:
            return self.beam_parse(docs, beam_width=beam_width,
                                   beam_density=beam_density, drop=drop)

    def greedy_parse(self, docs, drop=0.):
        cdef vector[StateC*] states
        cdef StateClass state
        set_dropout_rate(self.model, drop)
        batch = self.moves.init_batch(docs)
        # This is pretty dirty, but the NER can resize itself in init_batch,
        # if labels are missing. We therefore have to check whether we need to
        # expand our model output.
        self._resize()
        model = self.model.predict(docs)
        weights = get_c_weights(model)
        for state in batch:
            if not state.is_final():
                states.push_back(state.c)
        sizes = get_c_sizes(model, states.size())
        with nogil:
            self._parseC(&states[0],
                         weights, sizes)
        return batch
💫 Port master changes over to develop (#2979) * Create aryaprabhudesai.md (#2681) * Update _install.jade (#2688) Typo fix: "models" -> "model" * Add FAC to spacy.explain (resolves #2706) * Remove docstrings for deprecated arguments (see #2703) * When calling getoption() in conftest.py, pass a default option (#2709) * When calling getoption() in conftest.py, pass a default option This is necessary to allow testing an installed spacy by running: pytest --pyargs spacy * Add contributor agreement * update bengali token rules for hyphen and digits (#2731) * Less norm computations in token similarity (#2730) * Less norm computations in token similarity * Contributor agreement * Remove ')' for clarity (#2737) Sorry, don't mean to be nitpicky, I just noticed this when going through the CLI and thought it was a quick fix. That said, if this was intention than please let me know. * added contributor agreement for mbkupfer (#2738) * Basic support for Telugu language (#2751) * Lex _attrs for polish language (#2750) * Signed spaCy contributor agreement * Added polish version of english lex_attrs * Introduces a bulk merge function, in order to solve issue #653 (#2696) * Fix comment * Introduce bulk merge to increase performance on many span merges * Sign contributor agreement * Implement pull request suggestions * Describe converters more explicitly (see #2643) * Add multi-threading note to Language.pipe (resolves #2582) [ci skip] * Fix formatting * Fix dependency scheme docs (closes #2705) [ci skip] * Don't set stop word in example (closes #2657) [ci skip] * Add words to portuguese language _num_words (#2759) * Add words to portuguese language _num_words * Add words to portuguese language _num_words * Update Indonesian model (#2752) * adding e-KTP in tokenizer exceptions list * add exception token * removing lines with containing space as it won't matter since we use .split() method in the end, added new tokens in exception * add tokenizer exceptions list * combining base_norms 
with norm_exceptions * adding norm_exception * fix double key in lemmatizer * remove unused import on punctuation.py * reformat stop_words to reduce number of lines, improve readibility * updating tokenizer exception * implement is_currency for lang/id * adding orth_first_upper in tokenizer_exceptions * update the norm_exception list * remove bunch of abbreviations * adding contributors file * Fixed spaCy+Keras example (#2763) * bug fixes in keras example * created contributor agreement * Adding French hyphenated first name (#2786) * Fix typo (closes #2784) * Fix typo (#2795) [ci skip] Fixed typo on line 6 "regcognizer --> recognizer" * Adding basic support for Sinhala language. (#2788) * adding Sinhala language package, stop words, examples and lex_attrs. * Adding contributor agreement * Updating contributor agreement * Also include lowercase norm exceptions * Fix error (#2802) * Fix error ValueError: cannot resize an array that references or is referenced by another array in this way. Use the resize function * added spaCy Contributor Agreement * Add charlax's contributor agreement (#2805) * agreement of contributor, may I introduce a tiny pl languge contribution (#2799) * Contributors agreement * Contributors agreement * Contributors agreement * Add jupyter=True to displacy.render in documentation (#2806) * Revert "Also include lowercase norm exceptions" This reverts commit 70f4e8adf37cfcfab60be2b97d6deae949b30e9e. * Remove deprecated encoding argument to msgpack * Set up dependency tree pattern matching skeleton (#2732) * Fix bug when too many entity types. 
Fixes #2800 * Fix Python 2 test failure * Require older msgpack-numpy * Restore encoding arg on msgpack-numpy * Try to fix version pin for msgpack-numpy * Update Portuguese Language (#2790) * Add words to portuguese language _num_words * Add words to portuguese language _num_words * Portuguese - Add/remove stopwords, fix tokenizer, add currency symbols * Extended punctuation and norm_exceptions in the Portuguese language * Correct error in spacy universe docs concerning spacy-lookup (#2814) * Update Keras Example for (Parikh et al, 2016) implementation (#2803) * bug fixes in keras example * created contributor agreement * baseline for Parikh model * initial version of parikh 2016 implemented * tested asymmetric models * fixed grevious error in normalization * use standard SNLI test file * begin to rework parikh example * initial version of running example * start to document the new version * start to document the new version * Update Decompositional Attention.ipynb * fixed calls to similarity * updated the README * import sys package duh * simplified indexing on mapping word to IDs * stupid python indent error * added code from https://github.com/tensorflow/tensorflow/issues/3388 for tf bug workaround * Fix typo (closes #2815) [ci skip] * Update regex version dependency * Set version to 2.0.13.dev3 * Skip seemingly problematic test * Remove problematic test * Try previous version of regex * Revert "Remove problematic test" This reverts commit bdebbef45552d698d390aa430b527ee27830f11b. * Unskip test * Try older version of regex * 💫 Update training examples and use minibatching (#2830) <!--- Provide a general summary of your changes in the title. --> ## Description Update the training examples in `/examples/training` to show usage of spaCy's `minibatch` and `compounding` helpers ([see here](https://spacy.io/usage/training#tips-batch-size) for details). 
The lack of batching in the examples has caused some confusion in the past, especially for beginners who would copy-paste the examples, update them with large training sets and experienced slow and unsatisfying results. ### Types of change enhancements ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Visual C++ link updated (#2842) (closes #2841) [ci skip] * New landing page * Add contribution agreement * Correcting lang/ru/examples.py (#2845) * Correct some grammatical inaccuracies in lang\ru\examples.py; filled Contributor Agreement * Correct some grammatical inaccuracies in lang\ru\examples.py * Move contributor agreement to separate file * Set version to 2.0.13.dev4 * Add Persian(Farsi) language support (#2797) * Also include lowercase norm exceptions * Remove in favour of https://github.com/explosion/spaCy/graphs/contributors * Rule-based French Lemmatizer (#2818) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> Add a rule-based French Lemmatizer following the english one and the excellent PR for [greek language optimizations](https://github.com/explosion/spaCy/pull/2558) to adapt the Lemmatizer class. ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? 
--> - Lemma dictionary used can be found [here](http://infolingu.univ-mlv.fr/DonneesLinguistiques/Dictionnaires/telechargement.html), I used the XML version. - Add several files containing exhaustive list of words for each part of speech - Add some lemma rules - Add POS that are not checked in the standard Lemmatizer, i.e PRON, DET, ADV and AUX - Modify the Lemmatizer class to check in lookup table as a last resort if POS not mentionned - Modify the lemmatize function to check in lookup table as a last resort - Init files are updated so the model can support all the functionalities mentioned above - Add words to tokenizer_exceptions_list.py in respect to regex used in tokenizer_exceptions.py ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [X] I have submitted the spaCy Contributor Agreement. - [X] I ran the tests, and all new and existing tests passed. - [X] My changes don't require a change to the documentation, or if they do, I've added all required information. 
    def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
        cdef Beam beam
        cdef Doc doc
        cdef np.ndarray token_ids
        set_dropout_rate(self.model, drop)
        beams = self.moves.init_beams(docs, beam_width, beam_density=beam_density)
        # This is pretty dirty, but the NER can resize itself in init_batch,
        # if labels are missing. We therefore have to check whether we need to
        # expand our model output.
        self._resize()
        cdef int nr_feature = self.model.lower.get_dim("nF")
        model = self.model.predict(docs)
        token_ids = numpy.zeros((len(docs) * beam_width, nr_feature),
                                dtype='i', order='C')
        cdef int* c_ids
        cdef int n_states
        model = self.model.predict(docs)
        todo = [beam for beam in beams if not beam.is_done]
        while todo:
            token_ids.fill(-1)
            c_ids = <int*>token_ids.data
            n_states = 0
            for beam in todo:
                for i in range(beam.size):
                    state = <StateC*>beam.at(i)
                    # This way we avoid having to score finalized states
                    # We do have to take care to keep indexes aligned, though
                    if not state.is_final():
                        state.set_context_tokens(c_ids, nr_feature)
                        c_ids += nr_feature
                        n_states += 1
            if n_states == 0:
                break
* bringing layernorm back (had a bug - fixed in thinc) * revert setting nP explicitly * remove setting default in constructor * restore values as they used to be * add overfitting test for NER * add overfitting test for dep parser * add overfitting test for textcat * fixing init for linear (previously affine) * larger eps window for textcat * ensure doc is not None * Require newer thinc * Make float check vaguer * Slop the textcat overfit test more * Fix textcat test * Fix exclusive classes for textcat * fix after renaming of alloc methods * fixing renames and mandatory arguments (staticvectors WIP) * upgrade to thinc==8.0.0.dev3 * refer to vocab.vectors directly instead of its name * rename alpha to learn_rate * adding hashembed and staticvectors dropout * upgrade to thinc 8.0.0.dev4 * add name back to avoid warning W020 * thinc dev4 * update srsly * using thinc 8.0.0a0 ! Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Ines Montani <ines@ines.io>
2020-01-29 19:06:46 +03:00
vectors = model.state2vec.predict(token_ids[:n_states])
scores = model.vec2scores.predict(vectors)
todo = self.transition_beams(todo, scores)
return beams
    cdef void _parseC(self, StateC** states,
            WeightsC weights, SizesC sizes) nogil:
        # Greedy parsing loop; runs without the GIL.
        cdef int i, j
        cdef vector[StateC*] unfinished
        cdef ActivationsC activations = alloc_activations(sizes)
        while sizes.states >= 1:
            # Score every transition for each state still in the batch.
            predict_states(&activations,
                states, &weights, sizes)
            # Validate actions, argmax, take action.
            self.c_transition_batch(states,
                activations.scores, sizes.classes, sizes.states)
            # Collect the states that have not reached a final configuration.
            for i in range(sizes.states):
                if not states[i].is_final():
                    unfinished.push_back(states[i])
            # Compact the unfinished states to the front of the array and
            # shrink the batch so finished states are not scored again.
            for i in range(unfinished.size()):
                states[i] = unfinished[i]
            sizes.states = unfinished.size()
            unfinished.clear()
        free_activations(&activations)
2018-11-29 18:30:29 +03:00
    def set_annotations(self, docs, states_or_beams, tensors=None):
        cdef StateClass state
        cdef Beam beam
        cdef Doc doc
        states = []
        beams = []
        # Accept either plain states or beams; for a beam, borrow the state at index 0.
        for state_or_beam in states_or_beams:
            if isinstance(state_or_beam, StateClass):
                states.append(state_or_beam)
            else:
                beam = state_or_beam
                state = StateClass.borrow(<StateC*>beam.at(0))
                states.append(state)
                beams.append(beam)
        for i, (state, doc) in enumerate(zip(states, docs)):
            self.moves.finalize_state(state.c)
            # Copy the token structs from the parse state back into the Doc.
            for j in range(doc.length):
                doc.c[j] = state.c._sent[j]
            self.moves.finalize_doc(doc)
            for hook in self.postprocesses:
                hook(doc)
        # Clean up each beam once its annotations have been copied.
        for beam in beams:
            _beam_utils.cleanup_beam(beam)
    def transition_states(self, states, float[:, ::1] scores):
        cdef StateClass state
        cdef float* c_scores = &scores[0, 0]
        cdef vector[StateC*] c_states
        for state in states:
            c_states.push_back(state.c)
        self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
        return [state for state in states if not state.c.is_final()]

    cdef void c_transition_batch(self, StateC** states, const float* scores,
            int nr_class, int batch_size) nogil:
        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
        with gil:
            assert self.moves.n_moves > 0
        is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
        cdef int i, guess
        cdef Transition action
        for i in range(batch_size):
            self.moves.set_valid(is_valid, states[i])
            guess = arg_max_if_valid(&scores[i*nr_class], is_valid, nr_class)
            if guess == -1:
                # This shouldn't happen, but it's hard to raise an error here,
                # and we don't want to infinite loop. So, force to end state.
                states[i].force_final()
            else:
                action = self.moves.c[guess]
                action.do(states[i], action.label)
                states[i].push_hist(guess)
        free(is_valid)

    def transition_beams(self, beams, float[:, ::1] scores):
        cdef Beam beam
        cdef float* c_scores = &scores[0, 0]
        for beam in beams:
            for i in range(beam.size):
                state = <StateC*>beam.at(i)
                if not state.is_final():
                    self.moves.set_valid(beam.is_valid[i], state)
                    memcpy(beam.scores[i], c_scores, scores.shape[1] * sizeof(float))
                c_scores += scores.shape[1]
            beam.advance(_beam_utils.transition_state, _beam_utils.hash_state, <void*>self.moves.c)
            beam.check_done(_beam_utils.check_final_state, NULL)
        return [b for b in beams if not b.is_done]
    def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
        examples = Example.to_example_objects(examples)
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.)
        for multitask in self._multitasks:
            multitask.update(examples, drop=drop, sgd=sgd)
        # The probability we use beam update, instead of falling back to
        # a greedy update
        beam_update_prob = self.cfg['beam_update_prob']
        if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
            return self.update_beam(examples, self.cfg['beam_width'],
                    drop=drop, sgd=sgd, losses=losses, set_annotations=set_annotations,
                    beam_density=self.cfg.get('beam_density', 0.001))
        set_dropout_rate(self.model, drop)
        # Chop sequences into lengths of this many transitions, to make the
        # batch uniform length.
        cut_gold = numpy.random.choice(range(20, 100))
        states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
        states_golds = [(s, g) for (s, g) in zip(states, golds)
                        if not s.is_final() and g is not None]
        # Prepare the stepwise model, and get the callback for finishing the batch
        model, backprop_tok2vec = self.model.begin_update([ex.doc for ex in examples])
        all_states = list(states)
        for _ in range(max_steps):
            if not states_golds:
                break
            states, golds = zip(*states_golds)
            scores, backprop = model.begin_update(states)
            d_scores = self.get_batch_loss(states, golds, scores, losses)
            backprop(d_scores)
            # Follow the predicted action
            self.transition_states(states, scores)
            states_golds = [eg for eg in states_golds if not eg[0].is_final()]
* bringing layernorm back (had a bug - fixed in thinc) * revert setting nP explicitly * remove setting default in constructor * restore values as they used to be * add overfitting test for NER * add overfitting test for dep parser * add overfitting test for textcat * fixing init for linear (previously affine) * larger eps window for textcat * ensure doc is not None * Require newer thinc * Make float check vaguer * Slop the textcat overfit test more * Fix textcat test * Fix exclusive classes for textcat * fix after renaming of alloc methods * fixing renames and mandatory arguments (staticvectors WIP) * upgrade to thinc==8.0.0.dev3 * refer to vocab.vectors directly instead of its name * rename alpha to learn_rate * adding hashembed and staticvectors dropout * upgrade to thinc 8.0.0.dev4 * add name back to avoid warning W020 * thinc dev4 * update srsly * using thinc 8.0.0a0 ! Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Ines Montani <ines@ines.io>
2020-01-29 19:06:46 +03:00
        backprop_tok2vec(golds)
        if sgd is not None:
            self.model.finish_update(sgd)
        if set_annotations:
            docs = [ex.doc for ex in examples]
            self.set_annotations(docs, all_states)
        return losses

    def rehearse(self, examples, sgd=None, losses=None, **cfg):
        """Perform a "rehearsal" update, to prevent catastrophic forgetting."""
        examples = Example.to_example_objects(examples)
        if losses is None:
            losses = {}
        for multitask in self._multitasks:
            if hasattr(multitask, 'rehearse'):
                multitask.rehearse(examples, losses=losses, sgd=sgd)
        if self._rehearsal_model is None:
            return None
        losses.setdefault(self.name, 0.)
        docs = [ex.doc for ex in examples]
        states = self.moves.init_batch(docs)
        # This is pretty dirty, but the NER can resize itself in init_batch,
        # if labels are missing. We therefore have to check whether we need to
        # expand our model output.
        self._resize()
        # Prepare the stepwise model, and get the callback for finishing the batch
        set_dropout_rate(self._rehearsal_model, 0.0)
        set_dropout_rate(self.model, 0.0)
        tutor, _ = self._rehearsal_model.begin_update(docs)
        model, finish_update = self.model.begin_update(docs)
        n_scores = 0.
        loss = 0.
        while states:
            targets, _ = tutor.begin_update(states)
            guesses, backprop = model.begin_update(states)
            d_scores = (guesses - targets) / targets.shape[0]
            # If all weights for an output are 0 in the original model, don't
            # supervise that output. This allows us to add classes.
            loss += (d_scores**2).sum()
            backprop(d_scores, sgd=sgd)
            # Follow the predicted action
            self.transition_states(states, guesses)
            states = [state for state in states if not state.is_final()]
            n_scores += d_scores.size
        # Do the backprop
        finish_update(docs)
        if sgd is not None:
            self.model.finish_update(sgd)
        losses[self.name] += loss / n_scores
        return losses

    def update_beam(self, examples, width, drop=0., sgd=None, losses=None,
                    set_annotations=False, beam_density=0.0):
        examples = Example.to_example_objects(examples)
        docs = [ex.doc for ex in examples]
        golds = [ex.gold for ex in examples]
        new_golds = []
        lengths = [len(d) for d in docs]
        states = self.moves.init_batch(docs)
        for gold in golds:
            self.moves.preprocess_gold(gold)
            new_golds.append(gold)
        set_dropout_rate(self.model, drop)
        model, backprop_tok2vec = self.model.begin_update(docs)
        states_d_scores, backprops, beams = _beam_utils.update_beam(
            self.moves, self.model.lower.get_dim("nF"), 10000, states, golds,
* bringing layernorm back (had a bug - fixed in thinc) * revert setting nP explicitly * remove setting default in constructor * restore values as they used to be * add overfitting test for NER * add overfitting test for dep parser * add overfitting test for textcat * fixing init for linear (previously affine) * larger eps window for textcat * ensure doc is not None * Require newer thinc * Make float check vaguer * Slop the textcat overfit test more * Fix textcat test * Fix exclusive classes for textcat * fix after renaming of alloc methods * fixing renames and mandatory arguments (staticvectors WIP) * upgrade to thinc==8.0.0.dev3 * refer to vocab.vectors directly instead of its name * rename alpha to learn_rate * adding hashembed and staticvectors dropout * upgrade to thinc 8.0.0.dev4 * add name back to avoid warning W020 * thinc dev4 * update srsly * using thinc 8.0.0a0 ! Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Ines Montani <ines@ines.io>
2020-01-29 19:06:46 +03:00
            model.state2vec, model.vec2scores, width, losses=losses,
            beam_density=beam_density)
        for i, d_scores in enumerate(states_d_scores):
            losses[self.name] += (d_scores**2).mean()
            ids, bp_vectors, bp_scores = backprops[i]
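            d_vector = bp_scores(d_scores)
            # On GPU, move host-side arrays to the device asynchronously
            # on the model's CUDA stream before queuing the backprop.
            if isinstance(model.ops, CupyOps) \
                    and not isinstance(ids, model.state2vec.ops.xp.ndarray):
                model.backprops.append((
                    util.get_async(model.cuda_stream, ids),
                    util.get_async(model.cuda_stream, d_vector),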
                    bp_vectors))
            else:
                model.backprops.append((ids, d_vector, bp_vectors))
        backprop_tok2vec(golds)
        if sgd is not None:
            self.model.finish_update(sgd)
        if set_annotations:
            self.set_annotations(docs, beams)
        cdef Beam beam
        for beam in beams:
            _beam_utils.cleanup_beam(beam)
def get_gradients(self):
    """Get non-zero gradients of the model's parameters, as a dictionary
    keyed by the parameter ID. The values are (weights, gradients) tuples.
    """
    gradients = {}
    queue = [self.model]
    seen = set()
    for node in queue:
        if node.id in seen:
            continue
        seen.add(node.id)
        if hasattr(node, "_mem") and node._mem.gradient.any():
            gradients[node.id] = [node._mem.weights, node._mem.gradient]
        if hasattr(node, "_layers"):
            queue.extend(node._layers)
    return gradients

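# A minimal pure-Python sketch of the traversal used by get_gradients above:
# a list serves as a queue that grows during iteration, and a set of seen IDs
# keeps shared sub-layers from being visited twice. The Node class and the
# collect_nonzero_gradients helper are hypothetical stand-ins for illustration
# only; they are not thinc or spaCy classes.

```python
class Node:
    """Hypothetical stand-in for a model layer (not a thinc class)."""

    def __init__(self, id, gradient=None, layers=()):
        self.id = id
        self.gradient = gradient      # list of floats, or None if no params
        self.layers = list(layers)    # child layers, possibly shared


def collect_nonzero_gradients(root):
    """Walk the layer tree breadth-first; keep nodes with a non-zero gradient."""
    found = {}
    queue = [root]
    seen = set()
    for node in queue:  # the list is extended while we iterate over it
        if node.id in seen:
            continue
        seen.add(node.id)
        if node.gradient is not None and any(g != 0.0 for g in node.gradient):
            found[node.id] = node.gradient
        queue.extend(node.layers)
    return found


# A layer shared by two parents is only reported once.
shared = Node(2, gradient=[0.0, 0.5])
root = Node(0, layers=[Node(1, gradient=[0.0, 0.0], layers=[shared]), shared])
result = collect_nonzero_gradients(root)
```

# Here result holds only node 2: node 0 has no parameters, and node 1 has an
# all-zero gradient.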
def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
    """Make a square batch, of length equal to the shortest doc. A long
    doc will get multiple states. Let's say we have a doc of length 2*N,
    where N is the shortest doc. We'll make two states, one representing
    long_doc[:N], and another representing long_doc[N:]."""
    cdef:
        StateClass state
        Transition action
    whole_docs = [ex.doc for ex in whole_examples]
    whole_golds = [ex.gold for ex in whole_examples]
    whole_states = self.moves.init_batch(whole_docs)
    max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs])))
    max_moves = 0
    states = []
    golds = []
    for doc, state, gold in zip(whole_docs, whole_states, whole_golds):
        gold = self.moves.preprocess_gold(gold)
        if gold is None:
            continue
        oracle_actions = self.moves.get_oracle_sequence(doc, gold)
        start = 0
        while start < len(doc):
            state = state.copy()
            n_moves = 0
            while state.B(0) < start and not state.is_final():
                action = self.moves.c[oracle_actions.pop(0)]
                action.do(state.c, action.label)
                state.c.push_hist(action.clas)
                n_moves += 1
            has_gold = self.moves.has_gold(gold, start=start,
                                           end=start+max_length)
            if not state.is_final() and has_gold:
                states.append(state)
                golds.append(gold)
                max_moves = max(max_moves, n_moves)
            start += min(max_length, len(doc)-start)
        max_moves = max(max_moves, len(oracle_actions))
    return states, golds, max_moves

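# The windowing arithmetic in _init_gold_batch can be sketched on plain
# integers. The helper names window_length and window_starts are hypothetical
# and only mirror the two expressions used above; the real method also
# advances a parser state to each start offset and drops windows that are
# final or carry no gold annotation, which this sketch leaves out.

```python
def window_length(doc_lengths, min_length=5, max_length=500):
    """Length of the shortest doc, clamped to [min_length, max_length]."""
    return max(min_length, min(max_length, min(doc_lengths)))


def window_starts(doc_length, length):
    """Start offsets of consecutive windows covering a doc of doc_length tokens."""
    starts = []
    start = 0
    while start < doc_length:
        starts.append(start)
        # The final window may be shorter than `length`.
        start += min(length, doc_length - start)
    return starts


lengths = [6, 14, 20]
length = window_length(lengths)
windows = {n: window_starts(n, length) for n in lengths}
```

# With these lengths the window size is 6, so the 14-token doc is cut at
# offsets 0, 6 and 12, and the 20-token doc at 0, 6, 12 and 18.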
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
    cdef StateClass state
    cdef GoldParse gold
    cdef Pool mem = Pool()
    cdef int i

    # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
    assert self.moves.n_moves > 0

    is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
    costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
    cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
                                           dtype='f', order='C')
    c_d_scores = <float*>d_scores.data
    for i, (state, gold) in enumerate(zip(states, golds)):
        memset(is_valid, 0, self.moves.n_moves * sizeof(int))
        memset(costs, 0, self.moves.n_moves * sizeof(float))
        self.moves.set_costs(is_valid, costs, state, gold)
        for j in range(self.moves.n_moves):
            if costs[j] <= 0.0 and j in self.model.unseen_classes:
                self.model.unseen_classes.remove(j)
        cpu_log_loss(c_d_scores,
                     costs, is_valid, &scores[i, 0], d_scores.shape[1])
        c_d_scores += d_scores.shape[1]
    if losses is not None:
        losses.setdefault(self.name, 0.)
        losses[self.name] += (d_scores**2).sum()
    return d_scores

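# A rough pure-Python sketch of the bookkeeping in get_batch_loss: one
# gradient row per state, and a reported loss equal to the sum of squared
# gradient entries. The row_gradient function is an assumed stand-in
# (softmax over valid actions minus softmax over valid zero-cost actions);
# it is NOT the body of cpu_log_loss, which is defined elsewhere in the file.
# It also assumes every row has at least one valid zero-cost action. The
# names softmax, row_gradient and batch_loss are hypothetical.

```python
import math


def softmax(scores, mask):
    """Softmax restricted to the entries where mask is true; others get 0."""
    top = max(s for s, keep in zip(scores, mask) if keep)
    exps = [math.exp(s - top) if keep else 0.0 for s, keep in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def row_gradient(scores, costs, is_valid):
    """Assumed stand-in gradient: P(action | valid) - P(action | valid, zero cost)."""
    is_gold = [ok and c <= 0.0 for ok, c in zip(is_valid, costs)]
    guess = softmax(scores, is_valid)
    target = softmax(scores, is_gold)
    return [g - t for g, t in zip(guess, target)]


def batch_loss(all_scores, all_costs, all_valid, losses, name="parser"):
    """Build one gradient row per state and accumulate the squared-gradient loss."""
    d_scores = [row_gradient(s, c, v)
                for s, c, v in zip(all_scores, all_costs, all_valid)]
    if losses is not None:
        losses.setdefault(name, 0.0)
        losses[name] += sum(d * d for row in d_scores for d in row)
    return d_scores


losses = {}
d_scores = batch_loss(
    [[0.0, 0.0, 3.0]],        # scores for three actions
    [[0.0, 1.0, 0.0]],        # action 1 is costly
    [[True, True, False]],    # action 2 is invalid in this state
    losses,
)
```

# The invalid action receives a zero gradient even though it has the highest
# score, and the loss recorded under "parser" is the squared norm of the row.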
def create_optimizer(self):
* bringing layernorm back (had a bug - fixed in thinc) * revert setting nP explicitly * remove setting default in constructor * restore values as they used to be * add overfitting test for NER * add overfitting test for dep parser * add overfitting test for textcat * fixing init for linear (previously affine) * larger eps window for textcat * ensure doc is not None * Require newer thinc * Make float check vaguer * Slop the textcat overfit test more * Fix textcat test * Fix exclusive classes for textcat * fix after renaming of alloc methods * fixing renames and mandatory arguments (staticvectors WIP) * upgrade to thinc==8.0.0.dev3 * refer to vocab.vectors directly instead of its name * rename alpha to learn_rate * adding hashembed and staticvectors dropout * upgrade to thinc 8.0.0.dev4 * add name back to avoid warning W020 * thinc dev4 * update srsly * using thinc 8.0.0a0 ! Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Ines Montani <ines@ines.io>
2020-01-29 19:06:46 +03:00
        return create_default_optimizer()
💫 Port master changes over to develop (#2979) * Create aryaprabhudesai.md (#2681) * Update _install.jade (#2688) Typo fix: "models" -> "model" * Add FAC to spacy.explain (resolves #2706) * Remove docstrings for deprecated arguments (see #2703) * When calling getoption() in conftest.py, pass a default option (#2709) * When calling getoption() in conftest.py, pass a default option This is necessary to allow testing an installed spacy by running: pytest --pyargs spacy * Add contributor agreement * update bengali token rules for hyphen and digits (#2731) * Less norm computations in token similarity (#2730) * Less norm computations in token similarity * Contributor agreement * Remove ')' for clarity (#2737) Sorry, don't mean to be nitpicky, I just noticed this when going through the CLI and thought it was a quick fix. That said, if this was intention than please let me know. * added contributor agreement for mbkupfer (#2738) * Basic support for Telugu language (#2751) * Lex _attrs for polish language (#2750) * Signed spaCy contributor agreement * Added polish version of english lex_attrs * Introduces a bulk merge function, in order to solve issue #653 (#2696) * Fix comment * Introduce bulk merge to increase performance on many span merges * Sign contributor agreement * Implement pull request suggestions * Describe converters more explicitly (see #2643) * Add multi-threading note to Language.pipe (resolves #2582) [ci skip] * Fix formatting * Fix dependency scheme docs (closes #2705) [ci skip] * Don't set stop word in example (closes #2657) [ci skip] * Add words to portuguese language _num_words (#2759) * Add words to portuguese language _num_words * Add words to portuguese language _num_words * Update Indonesian model (#2752) * adding e-KTP in tokenizer exceptions list * add exception token * removing lines with containing space as it won't matter since we use .split() method in the end, added new tokens in exception * add tokenizer exceptions list * combining base_norms 
with norm_exceptions * adding norm_exception * fix double key in lemmatizer * remove unused import on punctuation.py * reformat stop_words to reduce number of lines, improve readibility * updating tokenizer exception * implement is_currency for lang/id * adding orth_first_upper in tokenizer_exceptions * update the norm_exception list * remove bunch of abbreviations * adding contributors file * Fixed spaCy+Keras example (#2763) * bug fixes in keras example * created contributor agreement * Adding French hyphenated first name (#2786) * Fix typo (closes #2784) * Fix typo (#2795) [ci skip] Fixed typo on line 6 "regcognizer --> recognizer" * Adding basic support for Sinhala language. (#2788) * adding Sinhala language package, stop words, examples and lex_attrs. * Adding contributor agreement * Updating contributor agreement * Also include lowercase norm exceptions * Fix error (#2802) * Fix error ValueError: cannot resize an array that references or is referenced by another array in this way. Use the resize function * added spaCy Contributor Agreement * Add charlax's contributor agreement (#2805) * agreement of contributor, may I introduce a tiny pl languge contribution (#2799) * Contributors agreement * Contributors agreement * Contributors agreement * Add jupyter=True to displacy.render in documentation (#2806) * Revert "Also include lowercase norm exceptions" This reverts commit 70f4e8adf37cfcfab60be2b97d6deae949b30e9e. * Remove deprecated encoding argument to msgpack * Set up dependency tree pattern matching skeleton (#2732) * Fix bug when too many entity types. 
Fixes #2800 * Fix Python 2 test failure * Require older msgpack-numpy * Restore encoding arg on msgpack-numpy * Try to fix version pin for msgpack-numpy * Update Portuguese Language (#2790) * Add words to portuguese language _num_words * Add words to portuguese language _num_words * Portuguese - Add/remove stopwords, fix tokenizer, add currency symbols * Extended punctuation and norm_exceptions in the Portuguese language * Correct error in spacy universe docs concerning spacy-lookup (#2814) * Update Keras Example for (Parikh et al, 2016) implementation (#2803) * bug fixes in keras example * created contributor agreement * baseline for Parikh model * initial version of parikh 2016 implemented * tested asymmetric models * fixed grevious error in normalization * use standard SNLI test file * begin to rework parikh example * initial version of running example * start to document the new version * start to document the new version * Update Decompositional Attention.ipynb * fixed calls to similarity * updated the README * import sys package duh * simplified indexing on mapping word to IDs * stupid python indent error * added code from https://github.com/tensorflow/tensorflow/issues/3388 for tf bug workaround * Fix typo (closes #2815) [ci skip] * Update regex version dependency * Set version to 2.0.13.dev3 * Skip seemingly problematic test * Remove problematic test * Try previous version of regex * Revert "Remove problematic test" This reverts commit bdebbef45552d698d390aa430b527ee27830f11b. * Unskip test * Try older version of regex * 💫 Update training examples and use minibatching (#2830) <!--- Provide a general summary of your changes in the title. --> ## Description Update the training examples in `/examples/training` to show usage of spaCy's `minibatch` and `compounding` helpers ([see here](https://spacy.io/usage/training#tips-batch-size) for details). 
The lack of batching in the examples has caused some confusion in the past, especially for beginners who would copy-paste the examples, update them with large training sets and experienced slow and unsatisfying results. ### Types of change enhancements ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Visual C++ link updated (#2842) (closes #2841) [ci skip] * New landing page * Add contribution agreement * Correcting lang/ru/examples.py (#2845) * Correct some grammatical inaccuracies in lang\ru\examples.py; filled Contributor Agreement * Correct some grammatical inaccuracies in lang\ru\examples.py * Move contributor agreement to separate file * Set version to 2.0.13.dev4 * Add Persian(Farsi) language support (#2797) * Also include lowercase norm exceptions * Remove in favour of https://github.com/explosion/spaCy/graphs/contributors * Rule-based French Lemmatizer (#2818) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> Add a rule-based French Lemmatizer following the english one and the excellent PR for [greek language optimizations](https://github.com/explosion/spaCy/pull/2558) to adapt the Lemmatizer class. ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? 
--> - Lemma dictionary used can be found [here](http://infolingu.univ-mlv.fr/DonneesLinguistiques/Dictionnaires/telechargement.html), I used the XML version. - Add several files containing exhaustive list of words for each part of speech - Add some lemma rules - Add POS that are not checked in the standard Lemmatizer, i.e PRON, DET, ADV and AUX - Modify the Lemmatizer class to check in lookup table as a last resort if POS not mentionned - Modify the lemmatize function to check in lookup table as a last resort - Init files are updated so the model can support all the functionalities mentioned above - Add words to tokenizer_exceptions_list.py in respect to regex used in tokenizer_exceptions.py ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [X] I have submitted the spaCy Contributor Agreement. - [X] I ran the tests, and all new and existing tests passed. - [X] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Set version to 2.0.13 * Fix formatting and consistency * Update docs for new version [ci skip] * Increment version [ci skip] * Add info on wheels [ci skip] * Adding "This is a sentence" example to Sinhala (#2846) * Add wheels badge * Update badge [ci skip] * Update README.rst [ci skip] * Update murmurhash pin * Increment version to 2.0.14.dev0 * Update GPU docs for v2.0.14 * Add wheel to setup_requires * Import prefer_gpu and require_gpu functions from Thinc * Add tests for prefer_gpu() and require_gpu() * Update requirements and setup.py * Workaround bug in thinc require_gpu * Set version to v2.0.14 * Update push-tag script * Unhack prefer_gpu * Require thinc 6.10.6 * Update prefer_gpu and require_gpu docs [ci skip] * Fix specifiers for GPU * Set version to 2.0.14.dev1 * Set version to 2.0.14 * Update Thinc version pin * Increment version * Fix msgpack-numpy version pin * Increment version * Update version to 2.0.16 * Update version [ci skip] * Redundant ')' in the Stop words' example (#2856) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [ ] I have submitted the spaCy Contributor Agreement. - [ ] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Documentation improvement regarding joblib and SO (#2867) Some documentation improvements ## Description 1. Fixed the dead URL to joblib 2. Fixed Stack Overflow brand name (with space) ### Types of change Documentation ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * raise error when setting overlapping entities as doc.ents (#2880) * Fix out-of-bounds access in NER training The helper method state.B(1) gets the index of the first token of the buffer, or -1 if no such token exists. Normally this is safe because we pass this to functions like state.safe_get(), which returns an empty token. Here we used it directly as an array index, which is not okay! This error may have been the cause of out-of-bounds access errors during training. Similar errors may still be around, so much be hunted down. Hunting this one down took a long time...I printed out values across training runs and diffed, looking for points of divergence between runs, when no randomness should be allowed. * Change PyThaiNLP Url (#2876) * Fix missing comma * Add example showing a fix-up rule for space entities * Set version to 2.0.17.dev0 * Update regex version * Revert "Update regex version" This reverts commit 62358dd867d15bc6a475942dff34effba69dd70a. 
* Try setting older regex version, to align with conda * Set version to 2.0.17 * Add spacy-js to universe [ci-skip] * Add spacy-raspberry to universe (closes #2889) * Add script to validate universe json [ci skip] * Removed space in docs + added contributor indo (#2909) * - removed unneeded space in documentation * - added contributor info * Allow input text of length up to max_length, inclusive (#2922) * Include universe spec for spacy-wordnet component (#2919) * feat: include universe spec for spacy-wordnet component * chore: include spaCy contributor agreement * Minor formatting changes [ci skip] * Fix image [ci skip] Twitter URL doesn't work on live site * Check if the word is in one of the regular lists specific to each POS (#2886) * 💫 Create random IDs for SVGs to prevent ID clashes (#2927) Resolves #2924. ## Description Fixes problem where multiple visualizations in Jupyter notebooks would have clashing arc IDs, resulting in weirdly positioned arc labels. Generating a random ID prefix so even identical parses won't receive the same IDs for consistency (even if effect of ID clash isn't noticable here.) ### Types of change bug fix ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Fix typo [ci skip] * fixes symbolic link on py3 and windows (#2949) * fixes symbolic link on py3 and windows during setup of spacy using command python -m spacy link en_core_web_sm en closes #2948 * Update spacy/compat.py Co-Authored-By: cicorias <cicorias@users.noreply.github.com> * Fix formatting * Update universe [ci skip] * Catalan Language Support (#2940) * Catalan language Support * Ddding Catalan to documentation * Sort languages alphabetically [ci skip] * Update tests for pytest 4.x (#2965) <!--- Provide a general summary of your changes in the title. --> ## Description - [x] Replace marks in params for pytest 4.0 compat ([see here](https://docs.pytest.org/en/latest/deprecations.html#marks-in-pytest-mark-parametrize)) - [x] Un-xfail passing tests (some fixes in a recent update resolved a bunch of issues, but tests were apparently never updated here) ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Fix regex pin to harmonize with conda (#2964) * Update README.rst * Fix bug where Vocab.prune_vector did not use 'batch_size' (#2977) Fixes #2976 * Fix typo * Fix typo * Remove duplicate file * Require thinc 7.0.0.dev2 Fixes bug in gpu_ops that would use cupy instead of numpy on CPU * Add missing import * Fix error IDs * Fix tests
2018-11-29 18:30:29 +03:00
Default settings to configurations (#4995) * fix grad_clip naming * cleaning up pretrained_vectors out of cfg * further refactoring Model init's * move Model building out of pipes * further refactor to require a model config when creating a pipe * small fixes * making cfg in nn_parser more consistent * fixing nr_class for parser * fixing nn_parser's nO * fix printing of loss * architectures in own file per type, consistent naming * convenience methods default_tagger_config and default_tok2vec_config * let create_pipe access default config if available for that component * default_parser_config * move defaults to separate folder * allow reading nlp from package or dir with argument 'name' * architecture spacy.VocabVectors.v1 to read static vectors from file * cleanup * default configs for nel, textcat, morphologizer, tensorizer * fix imports * fixing unit tests * fixes and clean up * fixing defaults, nO, fix unit tests * restore parser IO * fix IO * 'fix' serialization test * add *.cfg to manifest * fix example configs with additional arguments * replace Morpohologizer with Tagger * add IO bit when testing overfitting of tagger (currently failing) * fix IO - don't initialize when reading from disk * expand overfitting tests to also check IO goes OK * remove dropout from HashEmbed to fix Tagger performance * add defaults for sentrec * update thinc * always pass a Model instance to a Pipe * fix piped_added statement * remove obsolete W029 * remove obsolete errors * restore byte checking tests (work again) * clean up test * further test cleanup * convert from config to Model in create_pipe * bring back error when component is not initialized * cleanup * remove calls for nlp2.begin_training * use thinc.api in imports * allow setting charembed's nM and nC * fix for hardcoded nM/nC + unit test * formatting fixes * trigger build
2020-02-27 20:42:27 +03:00
    def set_output(self, nO):
        if self.model.upper.has_dim("nO") is None:
            self.model.upper.set_dim("nO", nO)

    def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
        self.cfg.update(kwargs)
        if not hasattr(get_examples, '__call__'):
            gold_tuples = get_examples
            get_examples = lambda: gold_tuples
        actions = self.moves.get_actions(gold_parses=get_examples(),
                                         min_freq=self.cfg['min_action_freq'],
                                         learn_tokens=self.cfg["learn_tokens"])
        for action, labels in self.moves.labels.items():
            actions.setdefault(action, {})
            for label, freq in labels.items():
                if label not in actions[action]:
                    actions[action][label] = freq
Improve label management in parser and NER (#2108) This patch does a few smallish things that tighten up the training workflow a little, and allow memory use during training to be reduced by letting the GoldCorpus stream data properly. Previously, the parser and entity recognizer read and saved labels as lists, with extra labels noted separately. Lists were used becaue ordering is very important, to ensure that the label-to-class mapping is stable. We now manage labels as nested dictionaries, first keyed by the action, and then keyed by the label. Values are frequencies. The trick is, how do we save new labels? We need to make sure we iterate over these in the same order they're added. Otherwise, we'll get different class IDs, and the model's predictions won't make sense. To allow stable sorting, we map the new labels to negative values. If we have two new labels, they'll be noted as having "frequency" -1 and -2. The next new label will then have "frequency" -3. When we sort by (frequency, label), we then get a stable sort. Storing frequencies then allows us to make the next nice improvement. Previously we had to iterate over the whole training set, to pre-process it for the deprojectivisation. This led to storing the whole training set in memory. This was most of the required memory during training. To prevent this, we now store the frequencies as we stream in the data, and deprojectivize as we go. Once we've built the frequencies, we can then apply a frequency cut-off when we decide how many classes to make. Finally, to allow proper data streaming, we also have to have some way of shuffling the iterator. This is awkward if the training files have multiple documents in them. To solve this, the GoldCorpus class now writes the training data to disk in msgpack files, one per document. We can then shuffle the data by shuffling the paths. This is a squash merge, as I made a lot of very small commits. Individual commit messages below. 
* Simplify label management for TransitionSystem and its subclasses * Fix serialization for new label handling format in parser * Simplify and improve GoldCorpus class. Reduce memory use, write to temp dir * Set actions in transition system * Require thinc 6.11.1.dev4 * Fix error in parser init * Add unicode declaration * Fix unicode declaration * Update textcat test * Try to get model training on less memory * Print json loc for now * Try rapidjson to reduce memory use * Remove rapidjson requirement * Try rapidjson for reduced mem usage * Handle None heads when projectivising * Stream json docs * Fix train script * Handle projectivity in GoldParse * Fix projectivity handling * Add minibatch_by_words util from ud_train * Minibatch by number of words in spacy.cli.train * Move minibatch_by_words util to spacy.util * Fix label handling * More hacking at label management in parser * Fix encoding in msgpack serialization in GoldParse * Adjust batch sizes in parser training * Fix minibatch_by_words * Add merge_subtokens function to pipeline.pyx * Register merge_subtokens factory * Restore use of msgpack tmp directory * Use minibatch-by-words in train * Handle retokenization in scorer * Change back-off approach for missing labels. Use 'dep' label * Update NER for new label management * Set NER tags for over-segmented words * Fix label alignment in gold * Fix label back-off for infrequent labels * Fix int type in labels dict key * Fix int type in labels dict key * Update feature definition for 8 feature set * Update ud-train script for new label stuff * Fix json streamer * Print the line number if conll eval fails * Update children and sentence boundaries after deprojectivisation * Export set_children_from_heads from doc.pxd * Render parses during UD training * Remove print statement * Require thinc 6.11.1.dev6. 
Try adding wheel as install_requires * Set different dev version, to flush pip cache * Update thinc version * Update GoldCorpus docs * Remove print statements * Fix formatting and links [ci skip]
2018-03-19 04:58:08 +03:00
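The commit message above describes recording new labels with negative "frequencies" so that class IDs stay stable. The snippet below is a minimal sketch of that idea using plain dictionaries. The helper names `add_label` and `class_order` are hypothetical, and the sort direction (descending frequency, then label name) is an assumption; the commit message only says the sort key is (frequency, label).

```python
def add_label(label_freqs, label):
    """Register a new label with the next negative pseudo-frequency.

    Hypothetical helper: the first new label gets -1, the second -2, and so on.
    """
    if label in label_freqs:
        return False
    lowest = min(label_freqs.values(), default=0)
    label_freqs[label] = min(0, lowest) - 1
    return True


def class_order(label_freqs):
    """Return labels in a reproducible order.

    Assumption: sort by descending frequency, then by label name, so counted
    labels come first and added labels follow in the order they were added.
    """
    ranked = sorted(label_freqs.items(), key=lambda item: (-item[1], item[0]))
    return [label for label, _ in ranked]


# Illustrative data only: two counted labels, then two labels added later.
labels = {"nsubj": 120, "dobj": 80}
add_label(labels, "zzz_new")
add_label(labels, "aaa_new")
order = class_order(labels)
print(order)
```

Because the pseudo-frequency decreases with each addition, the order of the added labels depends on when they were added rather than on their names, which is what keeps the label-to-class mapping stable across save and load.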
        self.moves.initialize_actions(actions)
        # make sure we resize so we have an appropriate upper layer
        self._resize()
        if sgd is None:
            sgd = self.create_optimizer()
        doc_sample = []
        gold_sample = []
        for example in islice(get_examples(), 1000):
            parses = example.get_gold_parses(merge=False, vocab=self.vocab)
            for doc, gold in parses:
                doc_sample.append(doc)
                gold_sample.append(gold)
        self.model.initialize(doc_sample, gold_sample)
        if pipeline is not None:
            self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
        link_vectors_to_models(self.vocab)
        return sgd
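The label-merging loop in `begin_training` can be tried in isolation. The following is a minimal sketch, assuming the nested `{action: {label: freq}}` layout described in the label-management commit message; `merge_known_labels` and the sample data are hypothetical and not part of spaCy.

```python
def merge_known_labels(actions, known_labels):
    """Add labels the transition system already knows to freshly counted
    actions, without overwriting frequencies counted from the data."""
    for action, labels in known_labels.items():
        actions.setdefault(action, {})
        for label, freq in labels.items():
            if label not in actions[action]:
                actions[action][label] = freq
    return actions


# Hypothetical action IDs and labels, for illustration only.
counted = {0: {"nsubj": 120, "dobj": 80}}
known = {0: {"nsubj": 5, "dep": -1}, 1: {"ROOT": 40}}
merged = merge_known_labels(counted, known)
print(merged)
```

Counts gathered from the training data take precedence; labels that were already registered are only filled in where the data did not supply them.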
--> - Lemma dictionary used can be found [here](http://infolingu.univ-mlv.fr/DonneesLinguistiques/Dictionnaires/telechargement.html), I used the XML version. - Add several files containing exhaustive list of words for each part of speech - Add some lemma rules - Add POS that are not checked in the standard Lemmatizer, i.e PRON, DET, ADV and AUX - Modify the Lemmatizer class to check in lookup table as a last resort if POS not mentionned - Modify the lemmatize function to check in lookup table as a last resort - Init files are updated so the model can support all the functionalities mentioned above - Add words to tokenizer_exceptions_list.py in respect to regex used in tokenizer_exceptions.py ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [X] I have submitted the spaCy Contributor Agreement. - [X] I ran the tests, and all new and existing tests passed. - [X] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Set version to 2.0.13 * Fix formatting and consistency * Update docs for new version [ci skip] * Increment version [ci skip] * Add info on wheels [ci skip] * Adding "This is a sentence" example to Sinhala (#2846) * Add wheels badge * Update badge [ci skip] * Update README.rst [ci skip] * Update murmurhash pin * Increment version to 2.0.14.dev0 * Update GPU docs for v2.0.14 * Add wheel to setup_requires * Import prefer_gpu and require_gpu functions from Thinc * Add tests for prefer_gpu() and require_gpu() * Update requirements and setup.py * Workaround bug in thinc require_gpu * Set version to v2.0.14 * Update push-tag script * Unhack prefer_gpu * Require thinc 6.10.6 * Update prefer_gpu and require_gpu docs [ci skip] * Fix specifiers for GPU * Set version to 2.0.14.dev1 * Set version to 2.0.14 * Update Thinc version pin * Increment version * Fix msgpack-numpy version pin * Increment version * Update version to 2.0.16 * Update version [ci skip] * Redundant ')' in the Stop words' example (#2856) <!--- Provide a general summary of your changes in the title. --> ## Description <!--- Use this section to describe your changes. If your changes required testing, include information about the testing environment and the tests you ran. If your test fixes a bug reported in an issue, don't forget to include the issue number. If your PR is still a work in progress, that's totally fine – just include a note to let us know. --> ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [ ] I have submitted the spaCy Contributor Agreement. - [ ] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Documentation improvement regarding joblib and SO (#2867) Some documentation improvements ## Description 1. Fixed the dead URL to joblib 2. Fixed Stack Overflow brand name (with space) ### Types of change Documentation ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * raise error when setting overlapping entities as doc.ents (#2880) * Fix out-of-bounds access in NER training The helper method state.B(1) gets the index of the first token of the buffer, or -1 if no such token exists. Normally this is safe because we pass this to functions like state.safe_get(), which returns an empty token. Here we used it directly as an array index, which is not okay! This error may have been the cause of out-of-bounds access errors during training. Similar errors may still be around, so much be hunted down. Hunting this one down took a long time...I printed out values across training runs and diffed, looking for points of divergence between runs, when no randomness should be allowed. * Change PyThaiNLP Url (#2876) * Fix missing comma * Add example showing a fix-up rule for space entities * Set version to 2.0.17.dev0 * Update regex version * Revert "Update regex version" This reverts commit 62358dd867d15bc6a475942dff34effba69dd70a. 
* Try setting older regex version, to align with conda * Set version to 2.0.17 * Add spacy-js to universe [ci-skip] * Add spacy-raspberry to universe (closes #2889) * Add script to validate universe json [ci skip] * Removed space in docs + added contributor indo (#2909) * - removed unneeded space in documentation * - added contributor info * Allow input text of length up to max_length, inclusive (#2922) * Include universe spec for spacy-wordnet component (#2919) * feat: include universe spec for spacy-wordnet component * chore: include spaCy contributor agreement * Minor formatting changes [ci skip] * Fix image [ci skip] Twitter URL doesn't work on live site * Check if the word is in one of the regular lists specific to each POS (#2886) * 💫 Create random IDs for SVGs to prevent ID clashes (#2927) Resolves #2924. ## Description Fixes problem where multiple visualizations in Jupyter notebooks would have clashing arc IDs, resulting in weirdly positioned arc labels. Generating a random ID prefix so even identical parses won't receive the same IDs for consistency (even if effect of ID clash isn't noticable here.) ### Types of change bug fix ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. 
* Fix typo [ci skip] * fixes symbolic link on py3 and windows (#2949) * fixes symbolic link on py3 and windows during setup of spacy using command python -m spacy link en_core_web_sm en closes #2948 * Update spacy/compat.py Co-Authored-By: cicorias <cicorias@users.noreply.github.com> * Fix formatting * Update universe [ci skip] * Catalan Language Support (#2940) * Catalan language Support * Ddding Catalan to documentation * Sort languages alphabetically [ci skip] * Update tests for pytest 4.x (#2965) <!--- Provide a general summary of your changes in the title. --> ## Description - [x] Replace marks in params for pytest 4.0 compat ([see here](https://docs.pytest.org/en/latest/deprecations.html#marks-in-pytest-mark-parametrize)) - [x] Un-xfail passing tests (some fixes in a recent update resolved a bunch of issues, but tests were apparently never updated here) ### Types of change <!-- What type of change does your PR cover? Is it a bug fix, an enhancement or new feature, or a change to the documentation? --> ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information. * Fix regex pin to harmonize with conda (#2964) * Update README.rst * Fix bug where Vocab.prune_vector did not use 'batch_size' (#2977) Fixes #2976 * Fix typo * Fix typo * Remove duplicate file * Require thinc 7.0.0.dev2 Fixes bug in gpu_ops that would use cupy instead of numpy on CPU * Add missing import * Fix error IDs * Fix tests
2018-11-29 18:30:29 +03:00
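The out-of-bounds fix described in the commit message above (a helper returning -1 when the buffer is empty, then being used directly as an array index) can be illustrated with a small sketch. All names here (`EMPTY_TOKEN`, `safe_get`, `first_buffer_index`) are hypothetical stand-ins, not spaCy's actual C-level API. Note that in pure Python a -1 index silently wraps to the last element rather than reading invalid memory, which arguably makes the mistake even harder to spot.

```python
# Hypothetical empty token, standing in for what a guarded accessor returns.
EMPTY_TOKEN = {"text": "", "ent_iob": 0}


def safe_get(tokens, i):
    """Return tokens[i], or the empty token when i is out of range (e.g. -1)."""
    if i < 0 or i >= len(tokens):
        return EMPTY_TOKEN
    return tokens[i]


def first_buffer_index(buffer):
    """Hypothetical helper: index of the first buffer token, or -1 if empty."""
    return buffer[0] if buffer else -1


tokens = [{"text": "London", "ent_iob": 3}, {"text": "calling", "ent_iob": 2}]
index = first_buffer_index([])   # empty buffer, so the sentinel -1 comes back
unsafe = tokens[index]           # Python wraps around to the last token
safe = safe_get(tokens, index)   # the guard returns the empty token instead
```

Routing every sentinel-producing lookup through a guarded accessor keeps the check in one place instead of relying on each call site to remember it.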
Fix Example details for train CLI / pipeline components (#4624) * Switch to train_dataset() function in train CLI * Fixes for pipe() methods in pipeline components * Don't clobber `examples` variable with `as_example` in pipe() methods * Remove unnecessary traversals of `examples` * Update Parser.pipe() for Examples * Add `as_examples` kwarg to `pipe()` with implementation to return `Example`s * Accept `Doc` or `Example` in `pipe()` with `_get_doc()` (copied from `Pipe`) * Fixes to Example implementation in spacy.gold * Move `make_projective` from an attribute of Example to an argument of `Example.get_gold_parses()` * Head of 0 are not treated as unset * Unset heads are set to self rather than `None` (which causes problems while projectivizing) * Check for `Doc` (not just not `None`) when creating GoldParses for pre-merged example * Don't clobber `examples` variable in `iter_gold_docs()` * Add/modify gold tests for handling projectivity * In JSON roundtrip compare results from `dev_dataset` rather than `train_dataset` to avoid projectivization (and other potential modifications) * Add test for projective train vs. nonprojective dev versions of the same `Doc` * Handle ignore_misaligned as arg rather than attr Move `ignore_misaligned` from an attribute of `Example` to an argument to `Example.get_gold_parses()`, which makes it parallel to `make_projective`. Add test with old and new align that checks whether `ignore_misaligned` errors are raised as expected (only for new align). * Remove unused attrs from gold.pxd Remove `ignore_misaligned` and `make_projective` from `gold.pxd` * Refer to Example.goldparse in iter_gold_docs() Use `Example.goldparse` in `iter_gold_docs()` instead of `Example.gold` because a `None` `GoldParse` is generated with ignore_misaligned and generating it on-the-fly can raise an unwanted AlignmentError * Update test for ignore_misaligned
2019-11-23 16:32:15 +03:00
    def _get_doc(self, example):
        """Use this method if the `example` can be either a Doc or an Example."""
        if isinstance(example, Doc):
            return example
        return example.doc

    def to_disk(self, path, exclude=tuple(), **kwargs):
        serializers = {
            'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
            'vocab': lambda p: self.vocab.to_disk(p),
            'moves': lambda p: self.moves.to_disk(p, exclude=["strings"]),
            'cfg': lambda p: srsly.write_json(p, self.cfg)
        }
        exclude = util.get_serialization_exclude(serializers, exclude, kwargs)
        util.to_disk(path, serializers, exclude)

    def from_disk(self, path, exclude=tuple(), **kwargs):
        deserializers = {
            'vocab': lambda p: self.vocab.from_disk(p),
            'moves': lambda p: self.moves.from_disk(p, exclude=["strings"]),
            'cfg': lambda p: self.cfg.update(srsly.read_json(p)),
Default settings to configurations (#4995) * fix grad_clip naming * cleaning up pretrained_vectors out of cfg * further refactoring Model init's * move Model building out of pipes * further refactor to require a model config when creating a pipe * small fixes * making cfg in nn_parser more consistent * fixing nr_class for parser * fixing nn_parser's nO * fix printing of loss * architectures in own file per type, consistent naming * convenience methods default_tagger_config and default_tok2vec_config * let create_pipe access default config if available for that component * default_parser_config * move defaults to separate folder * allow reading nlp from package or dir with argument 'name' * architecture spacy.VocabVectors.v1 to read static vectors from file * cleanup * default configs for nel, textcat, morphologizer, tensorizer * fix imports * fixing unit tests * fixes and clean up * fixing defaults, nO, fix unit tests * restore parser IO * fix IO * 'fix' serialization test * add *.cfg to manifest * fix example configs with additional arguments * replace Morpohologizer with Tagger * add IO bit when testing overfitting of tagger (currently failing) * fix IO - don't initialize when reading from disk * expand overfitting tests to also check IO goes OK * remove dropout from HashEmbed to fix Tagger performance * add defaults for sentrec * update thinc * always pass a Model instance to a Pipe * fix piped_added statement * remove obsolete W029 * remove obsolete errors * restore byte checking tests (work again) * clean up test * further test cleanup * convert from config to Model in create_pipe * bring back error when component is not initialized * cleanup * remove calls for nlp2.begin_training * use thinc.api in imports * allow setting charembed's nM and nC * fix for hardcoded nM/nC + unit test * formatting fixes * trigger build
2020-02-27 20:42:27 +03:00
            'model': lambda p: None,
        }
        exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
        util.from_disk(path, deserializers, exclude)
        if 'model' not in exclude:
            path = util.ensure_path(path)
            with (path / 'model').open('rb') as file_:
                bytes_data = file_.read()
            try:
                self._resize()
                self.model.from_bytes(bytes_data)
            except AttributeError:
                raise ValueError(Errors.E149)
        return self

    def to_bytes(self, exclude=tuple(), **kwargs):
        serializers = {
            "model": lambda: (self.model.to_bytes()),
            "vocab": lambda: self.vocab.to_bytes(),
            "moves": lambda: self.moves.to_bytes(exclude=["strings"]),
            "cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
        }
        exclude = util.get_serialization_exclude(serializers, exclude, kwargs)
        return util.to_bytes(serializers, exclude)

    def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
        deserializers = {
            "vocab": lambda b: self.vocab.from_bytes(b),
            "moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
            "model": lambda b: None,
        }
        exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
        msg = util.from_bytes(bytes_data, deserializers, exclude)
        if 'model' not in exclude:
            if 'model' in msg:
                try:
                    self.model.from_bytes(msg['model'])
                except AttributeError:
                    raise ValueError(Errors.E149)
        return self
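
The serialization methods above share one pattern: build a dict of named callbacks, drop the excluded keys, and load the model last so that `cfg` is already in place. The following is a minimal, self-contained sketch of that pattern using only the standard library. `to_bytes_sketch`, `from_bytes_sketch` and `ToyComponent` are hypothetical stand-ins written for illustration; they are not the real `util.to_bytes` / `util.from_bytes` helpers, and JSON is assumed here purely for simplicity.

```python
import json


def to_bytes_sketch(serializers, exclude=()):
    """Call every serializer whose key is not excluded and pack the results."""
    payload = {key: fn() for key, fn in serializers.items() if key not in exclude}
    return json.dumps(payload, sort_keys=True).encode("utf8")


def from_bytes_sketch(bytes_data, deserializers, exclude=()):
    """Unpack the payload and pass each present, non-excluded field to its
    callback. Returns the unpacked dict so the caller can inspect it later."""
    msg = json.loads(bytes_data.decode("utf8"))
    for key, fn in deserializers.items():
        if key not in exclude and key in msg:
            fn(msg[key])
    return msg


class ToyComponent:
    """Hypothetical component holding a config dict and some model weights."""

    def __init__(self):
        self.cfg = {}
        self.weights = []

    def to_bytes(self, exclude=()):
        serializers = {
            "cfg": lambda: self.cfg,
            "model": lambda: self.weights,
        }
        return to_bytes_sketch(serializers, exclude)

    def from_bytes(self, bytes_data, exclude=()):
        deserializers = {
            "cfg": lambda b: self.cfg.update(b),
            # The model callback is a no-op: the model is loaded after cfg.
            "model": lambda b: None,
        }
        msg = from_bytes_sketch(bytes_data, deserializers, exclude)
        if "model" not in exclude and "model" in msg:
            self.weights = list(msg["model"])
        return self


source = ToyComponent()
source.cfg["nr_class"] = 3
source.weights = [0.5, 1.5]

restored = ToyComponent().from_bytes(source.to_bytes())
cfg_only = ToyComponent().from_bytes(source.to_bytes(), exclude=["model"])
```

Registering a no-op callback for `"model"` keeps the key in the dict of known fields while deferring the actual load until the configuration has been restored.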