diff --git a/setup.cfg b/setup.cfg index 95ada08ef..0495ff258 100644 --- a/setup.cfg +++ b/setup.cfg @@ -86,6 +86,8 @@ cuda101 = cupy-cuda101>=5.0.0b4,<9.0.0 cuda102 = cupy-cuda102>=5.0.0b4,<9.0.0 +cuda110 = + cupy-cuda110>=5.0.0b4,<9.0.0 # Language tokenizers with external dependencies ja = sudachipy>=0.4.9 @@ -94,8 +96,6 @@ ko = natto-py==0.9.0 th = pythainlp>=2.0 -zh = - spacy-pkuseg==0.0.26 [bdist_wheel] universal = false diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 1194438de..37983cb1a 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -143,6 +143,9 @@ nO = null @architectures = "spacy-transformers.TransformerListener.v1" grad_factor = 1.0 +[components.textcat.model.tok2vec.pooling] +@layers = "reduce_mean.v1" + [components.textcat.model.linear_model] @architectures = "spacy.TextCatBOW.v1" exclusive_classes = false diff --git a/spacy/lang/zh/__init__.py b/spacy/lang/zh/__init__.py index 30560ed0d..9a8a21a63 100644 --- a/spacy/lang/zh/__init__.py +++ b/spacy/lang/zh/__init__.py @@ -17,7 +17,7 @@ from ... import util # fmt: off -_PKUSEG_INSTALL_MSG = "install spacy-pkuseg with `pip install spacy-pkuseg==0.0.26`" +_PKUSEG_INSTALL_MSG = "install spacy-pkuseg with `pip install \"spacy-pkuseg>=0.0.27,<0.1.0\"` or `conda install -c conda-forge \"spacy-pkuseg>=0.0.27,<0.1.0\"`" # fmt: on DEFAULT_CONFIG = """ diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index 181bbcf4c..2ec036810 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -61,14 +61,14 @@ def build_bow_text_classifier( @registry.architectures.register("spacy.TextCatEnsemble.v2") -def build_text_classifier( +def build_text_classifier_v2( tok2vec: Model[List[Doc], List[Floats2d]], linear_model: Model[List[Doc], Floats2d], nO: Optional[int] = None, ) -> Model[List[Doc], Floats2d]: exclusive_classes = not linear_model.attrs["multi_label"] with Model.define_operators({">>": chain, "|": concatenate}): - width = tok2vec.get_dim("nO") + width = tok2vec.maybe_get_dim("nO") cnn_model = ( tok2vec >> list2ragged() diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index a03c7daf0..305f8f5df 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -92,9 +92,6 @@ class Morphologizer(Tagger): # 2) labels_pos stores a mapping from morph+POS->POS cfg = {"labels_morph": labels_morph or {}, "labels_pos": labels_pos or {}} self.cfg = dict(sorted(cfg.items())) - # add mappings for empty morph - self.cfg["labels_morph"][Morphology.EMPTY_MORPH] = Morphology.EMPTY_MORPH - self.cfg["labels_pos"][Morphology.EMPTY_MORPH] = POS_IDS[""] @property def labels(self): @@ -201,8 +198,8 @@ class Morphologizer(Tagger): doc_tag_ids = doc_tag_ids.get() for j, tag_id in enumerate(doc_tag_ids): morph = self.labels[tag_id] - doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels_morph"][morph]) - doc.c[j].pos = self.cfg["labels_pos"][morph] + doc.c[j].morph = self.vocab.morphology.add(self.cfg["labels_morph"].get(morph, 0)) + doc.c[j].pos = self.cfg["labels_pos"].get(morph, 0) def get_loss(self, examples, scores): """Find the loss and gradient of loss for the batch of documents and @@ -228,12 +225,12 @@ class Morphologizer(Tagger): # doesn't, so if either is None, treat both as None here so that # truths doesn't end up with an unknown morph+POS combination if pos is None or morph is None: - pos = None - morph = None - label_dict = Morphology.feats_to_dict(morph) - if pos: - label_dict[self.POS_FEAT] = pos - label = self.vocab.strings[self.vocab.morphology.add(label_dict)] + label = None + else: + label_dict = Morphology.feats_to_dict(morph) + if pos: + label_dict[self.POS_FEAT] = pos + label = self.vocab.strings[self.vocab.morphology.add(label_dict)] eg_truths.append(label) truths.append(eg_truths) d_scores, loss = loss_func(scores, truths) diff --git a/spacy/scorer.py b/spacy/scorer.py index fe64c23ad..5cace8fda 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -512,7 +512,7 @@ class Scorer: negative_labels (Iterable[str]): The string values that refer to no annotation (e.g. "NIL") RETURNS (Dict[str, Any]): A dictionary containing the scores. - DOCS (TODO): https://nightly.spacy.io/api/scorer#score_links + DOCS: https://nightly.spacy.io/api/scorer#score_links """ f_per_type = {} for example in examples: diff --git a/spacy/tests/pipeline/test_morphologizer.py b/spacy/tests/pipeline/test_morphologizer.py index 85d1d6c8b..add42e00a 100644 --- a/spacy/tests/pipeline/test_morphologizer.py +++ b/spacy/tests/pipeline/test_morphologizer.py @@ -116,3 +116,23 @@ def test_overfitting_IO(): no_batch_deps = [doc.to_array([MORPH]) for doc in [nlp(text) for text in texts]] assert_equal(batch_deps_1, batch_deps_2) assert_equal(batch_deps_1, no_batch_deps) + + # Test without POS + nlp.remove_pipe("morphologizer") + nlp.add_pipe("morphologizer") + for example in train_examples: + for token in example.reference: + token.pos_ = "" + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(50): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + assert losses["morphologizer"] < 0.00001 + + # Test the trained model + test_text = "I like blue ham" + doc = nlp(test_text) + gold_morphs = ["Feat=N", "Feat=V", "", ""] + gold_pos_tags = ["", "", "", ""] + assert [str(t.morph) for t in doc] == gold_morphs + assert [t.pos_ for t in doc] == gold_pos_tags diff --git a/website/docs/api/transformer.md b/website/docs/api/transformer.md index 5754d2238..e31c8ad2c 100644 --- a/website/docs/api/transformer.md +++ b/website/docs/api/transformer.md @@ -61,11 +61,11 @@ on the transformer architectures and their arguments and hyperparameters. > nlp.add_pipe("transformer", config=DEFAULT_CONFIG) > ``` -| Setting | Description | -| ----------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `max_batch_items` | Maximum size of a padded batch. Defaults to `4096`. ~~int~~ | -| `set_extra_annotations` | Function that takes a batch of `Doc` objects and transformer outputs to set additional annotations on the `Doc`. The `Doc._.transformer_data` attribute is set prior to calling the callback. Defaults to `null_annotation_setter` (no additional annotations). ~~Callable[[List[Doc], FullTransformerBatch], None]~~ | -| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) wrapping the transformer. Defaults to [TransformerModel](/api/architectures#TransformerModel). ~~Model[List[Doc], FullTransformerBatch]~~ | +| Setting | Description | +| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `max_batch_items` | Maximum size of a padded batch. Defaults to `4096`. ~~int~~ | +| `set_extra_annotations` | Function that takes a batch of `Doc` objects and transformer outputs to set additional annotations on the `Doc`. The `Doc._.trf_data` attribute is set prior to calling the callback. Defaults to `null_annotation_setter` (no additional annotations). ~~Callable[[List[Doc], FullTransformerBatch], None]~~ | +| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) wrapping the transformer. Defaults to [TransformerModel](/api/architectures#TransformerModel). ~~Model[List[Doc], FullTransformerBatch]~~ | ```python https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py diff --git a/website/docs/usage/index.md b/website/docs/usage/index.md index ccb59e937..8d40ee61e 100644 --- a/website/docs/usage/index.md +++ b/website/docs/usage/index.md @@ -174,6 +174,7 @@ $ source .env/bin/activate # activate virtual env $ export PYTHONPATH=`pwd` # set Python path to spaCy dir $ pip install -r requirements.txt # install all requirements $ python setup.py build_ext --inplace # compile spaCy +$ python setup.py install # install spaCy ``` Compared to regular install via pip, the diff --git a/website/docs/usage/layers-architectures.md b/website/docs/usage/layers-architectures.md index aa62a77d4..641db02f5 100644 --- a/website/docs/usage/layers-architectures.md +++ b/website/docs/usage/layers-architectures.md @@ -843,6 +843,27 @@ def __call__(self, Doc doc): return doc ``` +There is one more optional method to implement: [`score`](/api/pipe#score) +calculates the performance of your component on a set of examples, and +returns the results as a dictionary: + +```python +### The score method +def score(self, examples: Iterable[Example]) -> Dict[str, Any]: + prf = PRFScore() + for example in examples: + ... + + return { + "rel_micro_p": prf.precision, + "rel_micro_r": prf.recall, + "rel_micro_f": prf.fscore, + } +``` + +This is particularly useful to see the scores on the development corpus +when training the component with [`spacy train`](/api/cli#training). + Once our `TrainablePipe` subclass is fully implemented, we can [register](/usage/processing-pipelines#custom-components-factories) the component with the [`@Language.factory`](/api/language#factory) decorator. This @@ -865,6 +886,11 @@ assigns it a name and lets you create the component with > [components.relation_extractor.model.get_candidates] > @misc = "rel_cand_generator.v1" > max_length = 20 +> +> [training.score_weights] +> rel_micro_p = 0.0 +> rel_micro_r = 0.0 +> rel_micro_f = 1.0 > ``` ```python @@ -876,6 +902,28 @@ def make_relation_extractor(nlp, name, model): return RelationExtractor(nlp.vocab, model, name) ``` +You can extend the decorator to include information such as the type of +annotations that are required for this component to run, the type of annotations +it produces, and the scores that can be calculated: + +```python +### Factory annotations {highlight="5-11"} +from spacy.language import Language + +@Language.factory( + "relation_extractor", + requires=["doc.ents", "token.ent_iob", "token.ent_type"], + assigns=["doc._.rel"], + default_score_weights={ + "rel_micro_p": None, + "rel_micro_r": None, + "rel_micro_f": None, + }, +) +def make_relation_extractor(nlp, name, model): + return RelationExtractor(nlp.vocab, model, name) +``` + diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md index 274ea5989..58c846e9d 100644 --- a/website/docs/usage/training.md +++ b/website/docs/usage/training.md @@ -969,7 +969,7 @@ import spacy from spacy.tokens import Doc, DocBin nlp = spacy.blank("en") -docbin = DocBin(nlp.vocab) +docbin = DocBin() words = ["Apple", "is", "looking", "at", "buying", "U.K.", "startup", "."] spaces = [True, True, True, True, True, True, True, False] ents = ["B-ORG", "O", "O", "O", "O", "B-GPE", "O", "O"] diff --git a/website/src/widgets/quickstart-install.js b/website/src/widgets/quickstart-install.js index 37ae10da4..6bb14b687 100644 --- a/website/src/widgets/quickstart-install.js +++ b/website/src/widgets/quickstart-install.js @@ -7,7 +7,7 @@ import { repo } from '../components/util' const DEFAULT_MODELS = ['en'] const DEFAULT_OPT = 'efficiency' const DEFAULT_HARDWARE = 'cpu' -const DEFAULT_CUDA = 'cuda100' +const DEFAULT_CUDA = 'cuda102' const CUDA = { '8.0': 'cuda80', '9.0': 'cuda90', @@ -16,56 +16,9 @@ const CUDA = { '10.0': 'cuda100', '10.1': 'cuda101', '10.2': 'cuda102', + '11.0': 'cuda110', } -const LANG_EXTRAS = ['zh', 'ja'] // only for languages with models -const DATA = [ - { - id: 'os', - title: 'Operating system', - options: [ - { id: 'mac', title: 'macOS / OSX', checked: true }, - { id: 'windows', title: 'Windows' }, - { id: 'linux', title: 'Linux' }, - ], - }, - { - id: 'package', - title: 'Package manager', - options: [ - { id: 'pip', title: 'pip', checked: true }, - { id: 'conda', title: 'conda' }, - { id: 'source', title: 'from source' }, - ], - }, - { - id: 'hardware', - title: 'Hardware', - options: [ - { id: 'cpu', title: 'CPU', checked: DEFAULT_HARDWARE === 'cpu' }, - { id: 'gpu', title: 'GPU', checked: DEFAULT_HARDWARE == 'gpu' }, - ], - dropdown: Object.keys(CUDA).map(id => ({ id: CUDA[id], title: `CUDA ${id}` })), - defaultValue: DEFAULT_CUDA, - }, - { - id: 'config', - title: 'Configuration', - multiple: true, - options: [ - { - id: 'venv', - title: 'virtual env', - help: 'Use a virtual environment and install spaCy into a user directory', - }, - { - id: 'train', - title: 'train models', - help: - 'Check this if you plan to train your own models with spaCy to install extra dependencies and data resources', - }, - ], - }, -] +const LANG_EXTRAS = ['ja'] // only for languages with models const QuickstartInstall = ({ id, title }) => { const [train, setTrain] = useState(false) @@ -99,7 +52,56 @@ const QuickstartInstall = ({ id, title }) => { const pkg = nightly ? 'spacy-nightly' : 'spacy' const models = languages.filter(({ models }) => models !== null) const data = [ - ...DATA, + { + id: 'os', + title: 'Operating system', + options: [ + { id: 'mac', title: 'macOS / OSX', checked: true }, + { id: 'windows', title: 'Windows' }, + { id: 'linux', title: 'Linux' }, + ], + }, + { + id: 'package', + title: 'Package manager', + options: [ + { id: 'pip', title: 'pip', checked: true }, + !nightly ? { id: 'conda', title: 'conda' } : null, + { id: 'source', title: 'from source' }, + ].filter(o => o), + }, + { + id: 'hardware', + title: 'Hardware', + options: [ + { id: 'cpu', title: 'CPU', checked: DEFAULT_HARDWARE === 'cpu' }, + { id: 'gpu', title: 'GPU', checked: DEFAULT_HARDWARE == 'gpu' }, + ], + dropdown: Object.keys(CUDA).map(id => ({ + id: CUDA[id], + title: `CUDA ${id}`, + })), + defaultValue: DEFAULT_CUDA, + }, + { + id: 'config', + title: 'Configuration', + multiple: true, + options: [ + { + id: 'venv', + title: 'virtual env', + help: + 'Use a virtual environment and install spaCy into a user directory', + }, + { + id: 'train', + title: 'train models', + help: + 'Check this if you plan to train your own models with spaCy to install extra dependencies and data resources', + }, + ], + }, { id: 'models', title: 'Trained pipelines', @@ -141,11 +143,6 @@ const QuickstartInstall = ({ id, title }) => { setters={setters} showDropdown={showDropdown} > - {nightly && ( - - # 🚨 Nightly releases are currently only available via pip - - )} python -m venv .env source .env/bin/activate @@ -180,15 +177,17 @@ const QuickstartInstall = ({ id, title }) => { pip install -r requirements.txt python setup.py build_ext --inplace - {(train || hardware == 'gpu') && ( - pip install -e '.[{pipExtras}]' - )} - - - conda install -c conda-forge spacy-transformers + + pip install {train || hardware == 'gpu' ? `'.[${pipExtras}]'` : '.'} + + + # packages only available via pip - conda install -c conda-forge spacy-lookups-data + pip install spacy-transformers + + + pip install spacy-lookups-data {models.map(({ code, models: modelOptions }) => {