diff --git a/.github/azure-steps.yml b/.github/azure-steps.yml
index 125b7de7d..85078268c 100644
--- a/.github/azure-steps.yml
+++ b/.github/azure-steps.yml
@@ -116,7 +116,7 @@ steps:
     displayName: "Run CPU tests"
 
   - script: |
-      python -m pip install --pre thinc-apple-ops
+      python -m pip install 'spacy[apple]'
       python -m pytest --pyargs spacy
     displayName: "Run CPU tests with thinc-apple-ops"
     condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.11'))
diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml
index c9833cdba..794adee85 100644
--- a/.github/workflows/lock.yml
+++ b/.github/workflows/lock.yml
@@ -15,11 +15,11 @@ jobs:
   action:
     runs-on: ubuntu-latest
     steps:
-      - uses: dessant/lock-threads@v3
+      - uses: dessant/lock-threads@v4
         with:
           process-only: 'issues'
           issue-inactive-days: '30'
-          issue-comment: >
-            This thread has been automatically locked since there
-            has not been any recent activity after it was closed.
+          issue-comment: >
+            This thread has been automatically locked since there
+            has not been any recent activity after it was closed.
             Please open a new issue for related bugs.
diff --git a/README.md b/README.md
index abfc3da67..195424551 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ parsing, **named entity recognition**, **text classification** and more,
 multi-task learning with pretrained **transformers** like BERT, as well as a
 production-ready [**training system**](https://spacy.io/usage/training) and easy
 model packaging, deployment and workflow management. spaCy is commercial
-open-source software, released under the MIT license.
+open-source software, released under the [MIT license](https://github.com/explosion/spaCy/blob/master/LICENSE).
 
 💫 **Version 3.4 out now!**
 [Check out the release notes here.](https://github.com/explosion/spaCy/releases)
@@ -46,6 +46,7 @@ open-source software, released under the MIT license.
 | 🛠 **[Changelog]** | Changes and version history. |
 | 💝 **[Contribute]** | How to contribute to the spaCy project and code base. |
 | spaCy Tailored Pipelines | Get a custom spaCy pipeline, tailor-made for your NLP problem by spaCy's core developers. Streamlined, production-ready, predictable and maintainable. Start by completing our 5-minute questionnaire to tell us what you need and we'll be in touch! **[Learn more →](https://explosion.ai/spacy-tailored-pipelines)** |
+| spaCy Tailored Analysis | Bespoke advice for problem solving, strategy and analysis for applied NLP projects. Services include data strategy, code reviews, pipeline design and annotation coaching. Curious? Fill in our 5-minute questionnaire to tell us what you need and we'll be in touch! **[Learn more →](https://explosion.ai/spacy-tailored-analysis)** |
 
 [spacy 101]: https://spacy.io/usage/spacy-101
 [new in v3.0]: https://spacy.io/usage/v3
@@ -59,6 +60,7 @@ open-source software, released under the MIT license.
 [changelog]: https://spacy.io/usage#changelog
 [contribute]: https://github.com/explosion/spaCy/blob/master/CONTRIBUTING.md
 
+
 ## 💬 Where to ask questions
 
 The spaCy project is maintained by the [spaCy team](https://explosion.ai/about).
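A note on the macOS CI step above: the step name still mentions thinc-apple-ops because the `spacy[apple]` extra is expected to pull that package in. A minimal sanity check of that assumption after installing the extra (a sketch; the module name is the one published by thinc-apple-ops):

```python
import importlib.util

# After `pip install 'spacy[apple]'`, the Apple-accelerated ops package should be present.
# (Assumption: the [apple] extra depends on thinc-apple-ops, as the CI step name implies.)
print("thinc-apple-ops installed:", importlib.util.find_spec("thinc_apple_ops") is not None)
```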
diff --git a/build-constraints.txt b/build-constraints.txt index 956973abf..c1e82f1b0 100644 --- a/build-constraints.txt +++ b/build-constraints.txt @@ -5,4 +5,5 @@ numpy==1.17.3; python_version=='3.8' and platform_machine!='aarch64' numpy==1.19.2; python_version=='3.8' and platform_machine=='aarch64' numpy==1.19.3; python_version=='3.9' numpy==1.21.3; python_version=='3.10' -numpy; python_version>='3.11' +numpy==1.23.2; python_version=='3.11' +numpy; python_version>='3.12' diff --git a/pyproject.toml b/pyproject.toml index 7abd7a96f..4b0da39b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "cymem>=2.0.2,<2.1.0", "preshed>=3.0.2,<3.1.0", "murmurhash>=0.28.0,<1.1.0", - "thinc>=8.1.0,<8.2.0", + "thinc>=9.0.0.dev0,<9.1.0", "numpy>=1.15.0", ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 7435e9ec7..bea699ac7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ preshed>=3.0.2,<3.1.0 thinc==8.1.0 ml_datasets>=0.2.0,<0.3.0 murmurhash>=0.28.0,<1.1.0 -wasabi>=0.9.1,<1.1.0 +wasabi>=0.9.1,<1.2.0 srsly>=2.4.3,<3.0.0 catalogue>=2.0.6,<2.1.0 typer>=0.3.0,<0.8.0 diff --git a/setup.cfg b/setup.cfg index 3c1bf5b0b..5158a1086 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,8 +38,8 @@ install_requires = murmurhash>=0.28.0,<1.1.0 cymem>=2.0.2,<2.1.0 preshed>=3.0.2,<3.1.0 - thinc>=8.1.0,<8.2.0 - wasabi>=0.9.1,<1.1.0 + thinc>=9.0.0.dev0,<9.1.0 + wasabi>=0.9.1,<1.2.0 srsly>=2.4.3,<3.0.0 catalogue>=2.0.6,<2.1.0 # Third-party dependencies diff --git a/setup.py b/setup.py index bc911276f..77a4cf283 100755 --- a/setup.py +++ b/setup.py @@ -38,7 +38,6 @@ MOD_NAMES = [ "spacy.pipeline.dep_parser", "spacy.pipeline._edit_tree_internals.edit_trees", "spacy.pipeline.morphologizer", - "spacy.pipeline.multitask", "spacy.pipeline.ner", "spacy.pipeline.pipe", "spacy.pipeline.trainable_pipe", @@ -49,6 +48,7 @@ MOD_NAMES = [ "spacy.pipeline._parser_internals.arc_eager", "spacy.pipeline._parser_internals.ner", "spacy.pipeline._parser_internals.nonproj", + "spacy.pipeline._parser_internals.search", "spacy.pipeline._parser_internals._state", "spacy.pipeline._parser_internals.stateclass", "spacy.pipeline._parser_internals.transition_system", @@ -68,6 +68,7 @@ MOD_NAMES = [ "spacy.matcher.dependencymatcher", "spacy.symbols", "spacy.vectors", + "spacy.tests.parser._search", ] COMPILE_OPTIONS = { "msvc": ["/Ox", "/EHsc"], diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index aab2c8d12..aabd1cfef 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -16,6 +16,7 @@ from .debug_config import debug_config # noqa: F401 from .debug_model import debug_model # noqa: F401 from .debug_diff import debug_diff # noqa: F401 from .evaluate import evaluate # noqa: F401 +from .apply import apply # noqa: F401 from .convert import convert # noqa: F401 from .init_pipeline import init_pipeline_cli # noqa: F401 from .init_config import init_config, fill_config # noqa: F401 diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 7ce006108..c46abffe5 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -158,15 +158,15 @@ def load_project_config( sys.exit(1) validate_project_version(config) validate_project_commands(config) + if interpolate: + err = f"{PROJECT_FILE} validation error" + with show_validation_error(title=err, hint_fill=False): + config = substitute_project_variables(config, overrides) # Make sure directories defined in config exist for subdir in config.get("directories", []): dir_path = path / subdir if not 
dir_path.exists(): dir_path.mkdir(parents=True) - if interpolate: - err = f"{PROJECT_FILE} validation error" - with show_validation_error(title=err, hint_fill=False): - config = substitute_project_variables(config, overrides) return config @@ -582,6 +582,29 @@ def setup_gpu(use_gpu: int, silent=None) -> None: local_msg.info("To switch to GPU 0, use the option: --gpu-id 0") +def walk_directory(path: Path, suffix: Optional[str] = None) -> List[Path]: + if not path.is_dir(): + return [path] + paths = [path] + locs = [] + seen = set() + for path in paths: + if str(path) in seen: + continue + seen.add(str(path)) + if path.parts[-1].startswith("."): + continue + elif path.is_dir(): + paths.extend(path.iterdir()) + elif suffix is not None and not path.parts[-1].endswith(suffix): + continue + else: + locs.append(path) + # It's good to sort these, in case the ordering messes up cache. + locs.sort() + return locs + + def _format_number(number: Union[int, float], ndigits: int = 2) -> str: """Formats a number (float or int) rounding to `ndigits`, without truncating trailing 0s, as happens with `round(number, ndigits)`""" diff --git a/spacy/cli/apply.py b/spacy/cli/apply.py new file mode 100644 index 000000000..9d170bc95 --- /dev/null +++ b/spacy/cli/apply.py @@ -0,0 +1,143 @@ +import tqdm +import srsly + +from itertools import chain +from pathlib import Path +from typing import Optional, List, Iterable, cast, Union + +from wasabi import msg + +from ._util import app, Arg, Opt, setup_gpu, import_code, walk_directory + +from ..tokens import Doc, DocBin +from ..vocab import Vocab +from ..util import ensure_path, load_model + + +path_help = """Location of the documents to predict on. +Can be a single file in .spacy format or a .jsonl file. +Files with other extensions are treated as single plain text documents. +If a directory is provided it is traversed recursively to grab +all files to be processed. +The files can be a mixture of .spacy, .jsonl and text files. +If .jsonl is provided the specified field is going +to be grabbed ("text" by default).""" + +out_help = "Path to save the resulting .spacy file" +code_help = ( + "Path to Python file with additional " "code (registered functions) to be imported" +) +gold_help = "Use gold preprocessing provided in the .spacy files" +force_msg = ( + "The provided output file already exists. " + "To force overwriting the output file, set the --force or -F flag." +) + + +DocOrStrStream = Union[Iterable[str], Iterable[Doc]] + + +def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]: + """ + Stream Doc objects from DocBin. + """ + docbin = DocBin().from_disk(path) + for doc in docbin.get_docs(vocab): + yield doc + + +def _stream_jsonl(path: Path, field: str) -> Iterable[str]: + """ + Stream "text" field from JSONL. If the field "text" is + not found it raises error. + """ + for entry in srsly.read_jsonl(path): + if field not in entry: + msg.fail( + f"{path} does not contain the required '{field}' field.", exits=1 + ) + else: + yield entry[field] + + +def _stream_texts(paths: Iterable[Path]) -> Iterable[str]: + """ + Yields strings from text files in paths. 
+ """ + for path in paths: + with open(path, "r") as fin: + text = fin.read() + yield text + + +@app.command("apply") +def apply_cli( + # fmt: off + model: str = Arg(..., help="Model name or path"), + data_path: Path = Arg(..., help=path_help, exists=True), + output_file: Path = Arg(..., help=out_help, dir_okay=False), + code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help), + text_key: str = Opt("text", "--text-key", "-tk", help="Key containing text string for JSONL"), + force_overwrite: bool = Opt(False, "--force", "-F", help="Force overwriting the output file"), + use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."), + batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."), + n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.") +): + """ + Apply a trained pipeline to documents to get predictions. + Expects a loadable spaCy pipeline and path to the data, which + can be a directory or a file. + The data files can be provided in multiple formats: + 1. .spacy files + 2. .jsonl files with a specified "field" to read the text from. + 3. Files with any other extension are assumed to be containing + a single document. + DOCS: https://spacy.io/api/cli#apply + """ + data_path = ensure_path(data_path) + output_file = ensure_path(output_file) + code_path = ensure_path(code_path) + if output_file.exists() and not force_overwrite: + msg.fail(force_msg, exits=1) + if not data_path.exists(): + msg.fail(f"Couldn't find data path: {data_path}", exits=1) + import_code(code_path) + setup_gpu(use_gpu) + apply(data_path, output_file, model, text_key, batch_size, n_process) + + +def apply( + data_path: Path, + output_file: Path, + model: str, + json_field: str, + batch_size: int, + n_process: int, +): + docbin = DocBin(store_user_data=True) + paths = walk_directory(data_path) + if len(paths) == 0: + docbin.to_disk(output_file) + msg.warn("Did not find data to process," + f" {data_path} seems to be an empty directory.") + return + nlp = load_model(model) + msg.good(f"Loaded model {model}") + vocab = nlp.vocab + streams: List[DocOrStrStream] = [] + text_files = [] + for path in paths: + if path.suffix == ".spacy": + streams.append(_stream_docbin(path, vocab)) + elif path.suffix == ".jsonl": + streams.append(_stream_jsonl(path, json_field)) + else: + text_files.append(path) + if len(text_files) > 0: + streams.append(_stream_texts(text_files)) + datagen = cast(DocOrStrStream, chain(*streams)) + for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)): + docbin.add(doc) + if output_file.suffix == "": + output_file = output_file.with_suffix(".spacy") + docbin.to_disk(output_file) diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index 04eb7078f..7f365ae2c 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Mapping, Optional, Any, List, Union +from typing import Callable, Iterable, Mapping, Optional, Any, Union from enum import Enum from pathlib import Path from wasabi import Printer @@ -7,7 +7,7 @@ import re import sys import itertools -from ._util import app, Arg, Opt +from ._util import app, Arg, Opt, walk_directory from ..training import docs_to_json from ..tokens import Doc, DocBin from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs @@ -189,33 +189,6 @@ def autodetect_ner_format(input_data: str) -> Optional[str]: return None -def walk_directory(path: Path, converter: str) -> List[Path]: - if not path.is_dir(): - 
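Together, the relocated `walk_directory` helper and the new `apply` command above form a small batch-prediction workflow. A usage sketch with hypothetical paths and an example pipeline name (the CLI equivalent would be roughly `python -m spacy apply en_core_web_sm ./corpus predictions.spacy`):

```python
from pathlib import Path

from spacy.cli._util import walk_directory
from spacy.cli.apply import apply

# Recursively collect input files; hidden files are skipped and the result is sorted.
# "corpus/" is a hypothetical directory used for illustration.
jsonl_files = walk_directory(Path("corpus"), suffix=".jsonl")
print(f"found {len(jsonl_files)} .jsonl files")

# Run a loadable pipeline over everything under corpus/ and write one DocBin.
# Arguments mirror apply(data_path, output_file, model, json_field, batch_size, n_process).
apply(Path("corpus"), Path("predictions.spacy"), "en_core_web_sm", "text",
      batch_size=16, n_process=1)
```

Note that `apply()` itself does not guard against overwriting an existing output file; that check lives in `apply_cli`.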
return [path] - paths = [path] - locs = [] - seen = set() - for path in paths: - if str(path) in seen: - continue - seen.add(str(path)) - if path.parts[-1].startswith("."): - continue - elif path.is_dir(): - paths.extend(path.iterdir()) - elif converter == "json" and not path.parts[-1].endswith("json"): - continue - elif converter == "conll" and not path.parts[-1].endswith("conll"): - continue - elif converter == "iob" and not path.parts[-1].endswith("iob"): - continue - else: - locs.append(path) - # It's good to sort these, in case the ordering messes up cache. - locs.sort() - return locs - - def verify_cli_args( msg: Printer, input_path: Path, diff --git a/spacy/cli/download.py b/spacy/cli/download.py index 0c9a32b93..4c998a6e0 100644 --- a/spacy/cli/download.py +++ b/spacy/cli/download.py @@ -8,7 +8,6 @@ from ._util import app, Arg, Opt, WHEEL_SUFFIX, SDIST_SUFFIX from .. import about from ..util import is_package, get_minor_version, run_command from ..util import is_prerelease_version -from ..errors import OLD_MODEL_SHORTCUTS @app.command( @@ -61,12 +60,6 @@ def download( version = components[-1] else: model_name = model - if model in OLD_MODEL_SHORTCUTS: - msg.warn( - f"As of spaCy v3.0, shortcuts like '{model}' are deprecated. Please " - f"use the full pipeline package name '{OLD_MODEL_SHORTCUTS[model]}' instead." - ) - model_name = OLD_MODEL_SHORTCUTS[model] compatibility = get_compatibility() version = get_version(model_name, compatibility) diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py index a109c4a5a..6dd174902 100644 --- a/spacy/cli/project/run.py +++ b/spacy/cli/project/run.py @@ -101,8 +101,8 @@ def project_run( if not (project_dir / dep).exists(): err = f"Missing dependency specified by command '{subcommand}': {dep}" err_help = "Maybe you forgot to run the 'project assets' command or a previous step?" - err_kwargs = {"exits": 1} if not dry else {} - msg.fail(err, err_help, **err_kwargs) + err_exits = 1 if not dry else None + msg.fail(err, err_help, exits=err_exits) check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION) with working_dir(project_dir) as current_dir: msg.divider(subcommand) diff --git a/spacy/errors.py b/spacy/errors.py index 82e7c52bc..e800be1fa 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -131,13 +131,6 @@ class Warnings(metaclass=ErrorsWithCodes): "and make it independent. For example, `replace_listeners = " "[\"model.tok2vec\"]` See the documentation for details: " "https://spacy.io/usage/training#config-components-listeners") - W088 = ("The pipeline component {name} implements a `begin_training` " - "method, which won't be called by spaCy. As of v3.0, `begin_training` " - "has been renamed to `initialize`, so you likely want to rename the " - "component method. See the documentation for details: " - "https://spacy.io/api/language#initialize") - W089 = ("As of spaCy v3.0, the `nlp.begin_training` method has been renamed " - "to `nlp.initialize`.") W090 = ("Could not locate any {format} files in path '{path}'.") W091 = ("Could not clean/remove the temp directory at {dir}: {msg}.") W092 = ("Ignoring annotations for sentence starts, as dependency heads are set.") @@ -250,8 +243,6 @@ class Errors(metaclass=ErrorsWithCodes): "https://spacy.io/usage/models") E011 = ("Unknown operator: '{op}'. 
Options: {opts}") E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}") - E016 = ("MultitaskObjective target should be function or one of: dep, " - "tag, ent, dep_tag_offset, ent_tag.") E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}") E018 = ("Can't retrieve string for hash '{hash_value}'. This usually " "refers to an issue with the `Vocab` or `StringStore`.") @@ -345,6 +336,11 @@ class Errors(metaclass=ErrorsWithCodes): "clear the existing vectors and resize the table.") E074 = ("Error interpreting compiled match pattern: patterns are expected " "to end with the attribute {attr}. Got: {bad_attr}.") + E079 = ("Error computing states in beam: number of predicted beams " + "({pbeams}) does not equal number of gold beams ({gbeams}).") + E080 = ("Duplicate state found in beam: {key}.") + E081 = ("Error getting gradient in beam: number of histories ({n_hist}) " + "does not equal number of losses ({losses}).") E082 = ("Error deprojectivizing parse: number of heads ({n_heads}), " "projective heads ({n_proj_heads}) and labels ({n_labels}) do not " "match.") @@ -727,13 +723,6 @@ class Errors(metaclass=ErrorsWithCodes): "method in component '{name}'. If you want to use this " "method, make sure it's overwritten on the subclass.") E940 = ("Found NaN values in scores.") - E941 = ("Can't find model '{name}'. It looks like you're trying to load a " - "model from a shortcut, which is obsolete as of spaCy v3.0. To " - "load the model, use its full name instead:\n\n" - "nlp = spacy.load(\"{full}\")\n\nFor more details on the available " - "models, see the models directory: https://spacy.io/models. If you " - "want to create a blank model, use spacy.blank: " - "nlp = spacy.blank(\"{name}\")") E942 = ("Executing `after_{name}` callback failed. Expected the function to " "return an initialized nlp object but got: {value}. 
Maybe " "you forgot to return the modified object in your function?") @@ -962,15 +951,6 @@ class Errors(metaclass=ErrorsWithCodes): "but got '{received_type}'") -# Deprecated model shortcuts, only used in errors and warnings -OLD_MODEL_SHORTCUTS = { - "en": "en_core_web_sm", "de": "de_core_news_sm", "es": "es_core_news_sm", - "pt": "pt_core_news_sm", "fr": "fr_core_news_sm", "it": "it_core_news_sm", - "nl": "nl_core_news_sm", "el": "el_core_news_sm", "nb": "nb_core_news_sm", - "lt": "lt_core_news_sm", "xx": "xx_ent_wiki_sm" -} - - # fmt: on diff --git a/spacy/lang/nl/stop_words.py b/spacy/lang/nl/stop_words.py index a2c6198e7..cd4fdefdf 100644 --- a/spacy/lang/nl/stop_words.py +++ b/spacy/lang/nl/stop_words.py @@ -15,7 +15,7 @@ STOP_WORDS = set( """ -aan af al alle alles allebei alleen allen als altijd ander anders andere anderen aangaangde aangezien achter achterna +aan af al alle alles allebei alleen allen als altijd ander anders andere anderen aangaande aangezien achter achterna afgelopen aldus alhoewel anderzijds ben bij bijna bijvoorbeeld behalve beide beiden beneden bent bepaald beter betere betreffende binnen binnenin boven diff --git a/spacy/language.py b/spacy/language.py index e0abfd5e7..dcb62aef0 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1239,15 +1239,6 @@ class Language: sgd(key, W, dW) # type: ignore[call-arg, misc] return losses - def begin_training( - self, - get_examples: Optional[Callable[[], Iterable[Example]]] = None, - *, - sgd: Optional[Optimizer] = None, - ) -> Optimizer: - warnings.warn(Warnings.W089, DeprecationWarning) - return self.initialize(get_examples, sgd=sgd) - def initialize( self, get_examples: Optional[Callable[[], Iterable[Example]]] = None, diff --git a/spacy/ml/parser_model.pyx b/spacy/ml/parser_model.pyx index 055fa0bad..91558683b 100644 --- a/spacy/ml/parser_model.pyx +++ b/spacy/ml/parser_model.pyx @@ -3,7 +3,6 @@ cimport numpy as np from libc.math cimport exp from libc.string cimport memset, memcpy from libc.stdlib cimport calloc, free, realloc -from thinc.backends.linalg cimport Vec, VecVec from thinc.backends.cblas cimport saxpy, sgemm import numpy @@ -102,11 +101,10 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states, sum_state_features(cblas, A.unmaxed, W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces) for i in range(n.states): - VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces], - W.feat_bias, 1., n.hiddens * n.pieces) + saxpy(cblas)(n.hiddens * n.pieces, 1., W.feat_bias, 1, &A.unmaxed[i*n.hiddens*n.pieces], 1) for j in range(n.hiddens): index = i * n.hiddens * n.pieces + j * n.pieces - which = Vec.arg_max(&A.unmaxed[index], n.pieces) + which = _arg_max(&A.unmaxed[index], n.pieces) A.hiddens[i*n.hiddens + j] = A.unmaxed[index + which] memset(A.scores, 0, n.states * n.classes * sizeof(float)) if W.hidden_weights == NULL: @@ -119,8 +117,7 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states, 0.0, A.scores, n.classes) # Add bias for i in range(n.states): - VecVec.add_i(&A.scores[i*n.classes], - W.hidden_bias, 1., n.classes) + saxpy(cblas)(n.classes, 1., W.hidden_bias, 1, &A.scores[i*n.classes], 1) # Set unseen classes to minimum value i = 0 min_ = A.scores[0] @@ -158,7 +155,8 @@ cdef void cpu_log_loss(float* d_scores, """Do multi-label log loss""" cdef double max_, gmax, Z, gZ best = arg_max_if_gold(scores, costs, is_valid, O) - guess = Vec.arg_max(scores, O) + guess = _arg_max(scores, O) + if best == -1 or guess == -1: # These shouldn't happen, but if they do, we want to 
make sure we don't # cause an OOB access. @@ -488,3 +486,15 @@ cdef class precompute_hiddens: return d_best.reshape((d_best.shape + (1,))) return state_vector, backprop_relu + +cdef inline int _arg_max(const float* scores, const int n_classes) nogil: + if n_classes == 2: + return 0 if scores[0] > scores[1] else 1 + cdef int i + cdef int best = 0 + cdef float mode = scores[0] + for i in range(1, n_classes): + if scores[i] > mode: + mode = scores[i] + best = i + return best diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pxd b/spacy/pipeline/_parser_internals/_beam_utils.pxd index de3573fbc..571f246b1 100644 --- a/spacy/pipeline/_parser_internals/_beam_utils.pxd +++ b/spacy/pipeline/_parser_internals/_beam_utils.pxd @@ -1,6 +1,6 @@ from ...typedefs cimport class_t, hash_t -# These are passed as callbacks to thinc.search.Beam +# These are passed as callbacks to .search.Beam cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1 cdef int check_final_state(void* _state, void* extra_args) except -1 diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx index fa7df2056..610c8ddee 100644 --- a/spacy/pipeline/_parser_internals/_beam_utils.pyx +++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx @@ -3,17 +3,16 @@ cimport numpy as np import numpy from cpython.ref cimport PyObject, Py_XDECREF -from thinc.extra.search cimport Beam -from thinc.extra.search import MaxViolation -from thinc.extra.search cimport MaxViolation from ...typedefs cimport hash_t, class_t from .transition_system cimport TransitionSystem, Transition from ...errors import Errors +from .search cimport Beam, MaxViolation +from .search import MaxViolation from .stateclass cimport StateC, StateClass -# These are passed as callbacks to thinc.search.Beam +# These are passed as callbacks to .search.Beam cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1: dest = _dest src = _src diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index 257b5ef8a..a79aef64a 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -15,7 +15,7 @@ from ...training.example cimport Example from .stateclass cimport StateClass from ._state cimport StateC, ArcC from ...errors import Errors -from thinc.extra.search cimport Beam +from .search cimport Beam cdef weight_t MIN_SCORE = -90000 cdef attr_t SUBTOK_LABEL = hash_string('subtok') diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index cc196d85a..53ed03523 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -6,7 +6,6 @@ from libcpp.vector cimport vector from cymem.cymem cimport Pool from collections import Counter -from thinc.extra.search cimport Beam from ...tokens.doc cimport Doc from ...tokens.span import Span @@ -17,6 +16,7 @@ from ...attrs cimport IS_SPACE from ...structs cimport TokenC, SpanC from ...training import split_bilu_label from ...training.example cimport Example +from .search cimport Beam from .stateclass cimport StateClass from ._state cimport StateC from .transition_system cimport Transition, do_func_t diff --git a/spacy/pipeline/_parser_internals/search.pxd b/spacy/pipeline/_parser_internals/search.pxd new file mode 100644 index 000000000..dfe30e1c1 --- /dev/null +++ b/spacy/pipeline/_parser_internals/search.pxd @@ -0,0 +1,89 @@ +from 
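The `_arg_max` helper added to `parser_model.pyx` above replaces `Vec.arg_max` from `thinc.backends.linalg`. Before the new `search` module below, here is a pure-Python mirror of its general loop (a sketch, not the compiled code):

```python
def arg_max(scores):
    """Return the index of the highest score; ties keep the first (lowest) index."""
    best = 0
    mode = scores[0]
    for i, score in enumerate(scores[1:], start=1):
        if score > mode:
            mode = score
            best = i
    return best


assert arg_max([0.1, 0.9, 0.3]) == 1
```

One subtlety visible in the Cython version: the two-class fast path returns index 1 when the two scores are exactly equal, whereas the general loop keeps the first maximum.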
cymem.cymem cimport Pool + +from libc.stdint cimport uint32_t +from libc.stdint cimport uint64_t +from libcpp.pair cimport pair +from libcpp.queue cimport priority_queue +from libcpp.vector cimport vector + +from ...typedefs cimport class_t, weight_t, hash_t + +ctypedef pair[weight_t, size_t] Entry +ctypedef priority_queue[Entry] Queue + + +ctypedef int (*trans_func_t)(void* dest, void* src, class_t clas, void* x) except -1 + +ctypedef void* (*init_func_t)(Pool mem, int n, void* extra_args) except NULL + +ctypedef int (*del_func_t)(Pool mem, void* state, void* extra_args) except -1 + +ctypedef int (*finish_func_t)(void* state, void* extra_args) except -1 + +ctypedef hash_t (*hash_func_t)(void* state, void* x) except 0 + + +cdef struct _State: + void* content + class_t* hist + weight_t score + weight_t loss + int i + int t + bint is_done + + +cdef class Beam: + cdef Pool mem + cdef class_t nr_class + cdef class_t width + cdef class_t size + cdef public weight_t min_density + cdef int t + cdef readonly bint is_done + cdef list histories + cdef list _parent_histories + cdef weight_t** scores + cdef int** is_valid + cdef weight_t** costs + cdef _State* _parents + cdef _State* _states + cdef del_func_t del_func + + cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1 + + cdef inline void* at(self, int i) nogil: + return self._states[i].content + + cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1 + cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func, + void* extra_args) except -1 + cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1 + + + cdef inline void set_cell(self, int i, int j, weight_t score, int is_valid, weight_t cost) nogil: + self.scores[i][j] = score + self.is_valid[i][j] = is_valid + self.costs[i][j] = cost + + cdef int set_row(self, int i, const weight_t* scores, const int* is_valid, + const weight_t* costs) except -1 + cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1 + + +cdef class MaxViolation: + cdef Pool mem + cdef weight_t cost + cdef weight_t delta + cdef readonly weight_t p_score + cdef readonly weight_t g_score + cdef readonly double Z + cdef readonly double gZ + cdef class_t n + cdef readonly list p_hist + cdef readonly list g_hist + cdef readonly list p_probs + cdef readonly list g_probs + + cpdef int check(self, Beam pred, Beam gold) except -1 + cpdef int check_crf(self, Beam pred, Beam gold) except -1 diff --git a/spacy/pipeline/_parser_internals/search.pyx b/spacy/pipeline/_parser_internals/search.pyx new file mode 100644 index 000000000..1d9b6dd7a --- /dev/null +++ b/spacy/pipeline/_parser_internals/search.pyx @@ -0,0 +1,306 @@ +# cython: profile=True, experimental_cpp_class_def=True, cdivision=True, infer_types=True +cimport cython +from libc.string cimport memset, memcpy +from libc.math cimport log, exp +import math + +from cymem.cymem cimport Pool +from preshed.maps cimport PreshMap + + +cdef class Beam: + def __init__(self, class_t nr_class, class_t width, weight_t min_density=0.0): + assert nr_class != 0 + assert width != 0 + self.nr_class = nr_class + self.width = width + self.min_density = min_density + self.size = 1 + self.t = 0 + self.mem = Pool() + self.del_func = NULL + self._parents = <_State*>self.mem.alloc(self.width, sizeof(_State)) + self._states = <_State*>self.mem.alloc(self.width, sizeof(_State)) + cdef int i + self.histories = [[] for i in range(self.width)] + 
self._parent_histories = [[] for i in range(self.width)] + + self.scores = self.mem.alloc(self.width, sizeof(weight_t*)) + self.is_valid = self.mem.alloc(self.width, sizeof(weight_t*)) + self.costs = self.mem.alloc(self.width, sizeof(weight_t*)) + for i in range(self.width): + self.scores[i] = self.mem.alloc(self.nr_class, sizeof(weight_t)) + self.is_valid[i] = self.mem.alloc(self.nr_class, sizeof(int)) + self.costs[i] = self.mem.alloc(self.nr_class, sizeof(weight_t)) + + def __len__(self): + return self.size + + property score: + def __get__(self): + return self._states[0].score + + property min_score: + def __get__(self): + return self._states[self.size-1].score + + property loss: + def __get__(self): + return self._states[0].loss + + property probs: + def __get__(self): + return _softmax([self._states[i].score for i in range(self.size)]) + + property scores: + def __get__(self): + return [self._states[i].score for i in range(self.size)] + + property histories: + def __get__(self): + return self.histories + + cdef int set_row(self, int i, const weight_t* scores, const int* is_valid, + const weight_t* costs) except -1: + cdef int j + for j in range(self.nr_class): + self.scores[i][j] = scores[j] + self.is_valid[i][j] = is_valid[j] + self.costs[i][j] = costs[j] + + cdef int set_table(self, weight_t** scores, int** is_valid, weight_t** costs) except -1: + cdef int i, j + for i in range(self.width): + memcpy(self.scores[i], scores[i], sizeof(weight_t) * self.nr_class) + memcpy(self.is_valid[i], is_valid[i], sizeof(bint) * self.nr_class) + memcpy(self.costs[i], costs[i], sizeof(int) * self.nr_class) + + cdef int initialize(self, init_func_t init_func, del_func_t del_func, int n, void* extra_args) except -1: + for i in range(self.width): + self._states[i].content = init_func(self.mem, n, extra_args) + self._parents[i].content = init_func(self.mem, n, extra_args) + self.del_func = del_func + + def __dealloc__(self): + if self.del_func == NULL: + return + + for i in range(self.width): + self.del_func(self.mem, self._states[i].content, NULL) + self.del_func(self.mem, self._parents[i].content, NULL) + + @cython.cdivision(True) + cdef int advance(self, trans_func_t transition_func, hash_func_t hash_func, + void* extra_args) except -1: + cdef weight_t** scores = self.scores + cdef int** is_valid = self.is_valid + cdef weight_t** costs = self.costs + + cdef Queue* q = new Queue() + self._fill(q, scores, is_valid) + # For a beam of width k, we only ever need 2k state objects. How? + # Each transition takes a parent and a class and produces a new state. + # So, we don't need the whole history --- just the parent. So at + # each step, we take a parent, and apply one or more extensions to + # it. + self._parents, self._states = self._states, self._parents + self._parent_histories, self.histories = self.histories, self._parent_histories + cdef weight_t score + cdef int p_i + cdef int i = 0 + cdef class_t clas + cdef _State* parent + cdef _State* state + cdef hash_t key + cdef PreshMap seen_states = PreshMap(self.width) + cdef uint64_t is_seen + cdef uint64_t one = 1 + while i < self.width and not q.empty(): + data = q.top() + p_i = data.second / self.nr_class + clas = data.second % self.nr_class + score = data.first + q.pop() + parent = &self._parents[p_i] + # Indicates terminal state reached; i.e. state is done + if parent.is_done: + # Now parent will not be changed, so we don't have to copy. + # Once finished, should also be unbranching. 
+ self._states[i], parent[0] = parent[0], self._states[i] + parent.i = self._states[i].i + parent.t = self._states[i].t + parent.is_done = self._states[i].t + self._states[i].score = score + self.histories[i] = list(self._parent_histories[p_i]) + i += 1 + else: + state = &self._states[i] + # The supplied transition function should adjust the destination + # state to be the result of applying the class to the source state + transition_func(state.content, parent.content, clas, extra_args) + key = hash_func(state.content, extra_args) if hash_func is not NULL else 0 + is_seen = seen_states.get(key) + if key == 0 or key == 1 or not is_seen: + if key != 0 and key != 1: + seen_states.set(key, one) + state.score = score + state.loss = parent.loss + costs[p_i][clas] + self.histories[i] = list(self._parent_histories[p_i]) + self.histories[i].append(clas) + i += 1 + del q + self.size = i + assert self.size >= 1 + for i in range(self.width): + memset(self.scores[i], 0, sizeof(weight_t) * self.nr_class) + memset(self.costs[i], 0, sizeof(weight_t) * self.nr_class) + memset(self.is_valid[i], 0, sizeof(int) * self.nr_class) + self.t += 1 + + cdef int check_done(self, finish_func_t finish_func, void* extra_args) except -1: + cdef int i + for i in range(self.size): + if not self._states[i].is_done: + self._states[i].is_done = finish_func(self._states[i].content, extra_args) + for i in range(self.size): + if not self._states[i].is_done: + self.is_done = False + break + else: + self.is_done = True + + @cython.cdivision(True) + cdef int _fill(self, Queue* q, weight_t** scores, int** is_valid) except -1: + """Populate the queue from a k * n matrix of scores, where k is the + beam-width, and n is the number of classes. + """ + cdef Entry entry + cdef weight_t score + cdef _State* s + cdef int i, j, move_id + assert self.size >= 1 + cdef vector[Entry] entries + for i in range(self.size): + s = &self._states[i] + move_id = i * self.nr_class + if s.is_done: + # Update score by path average, following TACL '13 paper. + if self.histories[i]: + entry.first = s.score + (s.score / self.t) + else: + entry.first = s.score + entry.second = move_id + entries.push_back(entry) + else: + for j in range(self.nr_class): + if is_valid[i][j]: + entry.first = s.score + scores[i][j] + entry.second = move_id + j + entries.push_back(entry) + cdef double max_, Z, cutoff + if self.min_density == 0.0: + for i in range(entries.size()): + q.push(entries[i]) + elif not entries.empty(): + max_ = entries[0].first + Z = 0. + cutoff = 0. + # Softmax into probabilities, so we can prune + for i in range(entries.size()): + if entries[i].first > max_: + max_ = entries[i].first + for i in range(entries.size()): + Z += exp(entries[i].first-max_) + cutoff = (1. 
/ Z) * self.min_density + for i in range(entries.size()): + prob = exp(entries[i].first-max_) / Z + if prob >= cutoff: + q.push(entries[i]) + + +cdef class MaxViolation: + def __init__(self): + self.p_score = 0.0 + self.g_score = 0.0 + self.Z = 0.0 + self.gZ = 0.0 + self.delta = -1 + self.cost = 0 + self.p_hist = [] + self.g_hist = [] + self.p_probs = [] + self.g_probs = [] + + cpdef int check(self, Beam pred, Beam gold) except -1: + cdef _State* p = &pred._states[0] + cdef _State* g = &gold._states[0] + cdef weight_t d = p.score - g.score + if p.loss >= 1 and (self.cost == 0 or d > self.delta): + self.cost = p.loss + self.delta = d + self.p_hist = list(pred.histories[0]) + self.g_hist = list(gold.histories[0]) + self.p_score = p.score + self.g_score = g.score + self.Z = 1e-10 + self.gZ = 1e-10 + for i in range(pred.size): + if pred._states[i].loss > 0: + self.Z += exp(pred._states[i].score) + for i in range(gold.size): + if gold._states[i].loss == 0: + prob = exp(gold._states[i].score) + self.Z += prob + self.gZ += prob + + cpdef int check_crf(self, Beam pred, Beam gold) except -1: + d = pred.score - gold.score + seen_golds = set([tuple(gold.histories[i]) for i in range(gold.size)]) + if pred.loss > 0 and (self.cost == 0 or d > self.delta): + p_hist = [] + p_scores = [] + g_hist = [] + g_scores = [] + for i in range(pred.size): + if pred._states[i].loss > 0: + p_scores.append(pred._states[i].score) + p_hist.append(list(pred.histories[i])) + # This can happen from non-monotonic actions + # If we find a better gold analysis this way, be sure to keep it. + elif pred._states[i].loss <= 0 \ + and tuple(pred.histories[i]) not in seen_golds: + g_scores.append(pred._states[i].score) + g_hist.append(list(pred.histories[i])) + for i in range(gold.size): + if gold._states[i].loss == 0: + g_scores.append(gold._states[i].score) + g_hist.append(list(gold.histories[i])) + + all_probs = _softmax(p_scores + g_scores) + p_probs = all_probs[:len(p_scores)] + g_probs_all = all_probs[len(p_scores):] + g_probs = _softmax(g_scores) + + self.cost = pred.loss + self.delta = d + self.p_hist = p_hist + self.g_hist = g_hist + # TODO: These variables are misnamed! These are the gradients of the loss. + self.p_probs = p_probs + # Intuition here: + # The gradient of the loss is: + # P(model) - P(truth) + # Normally, P(truth) is 1 for the gold + # But, if we want to do the "partial credit" scheme, we want + # to create a distribution over the gold, proportional to the scores + # awarded. 
+ self.g_probs = [x-y for x, y in zip(g_probs_all, g_probs)] + + +def _softmax(nums): + if not nums: + return [] + max_ = max(nums) + nums = [(exp(n-max_) if n is not None else None) for n in nums] + Z = sum(n for n in nums if n is not None) + return [(n/Z if n is not None else None) for n in nums] diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index 9676e2194..2a2242aa4 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -5,8 +5,9 @@ from itertools import islice import numpy as np import srsly -from thinc.api import Config, Model, SequenceCategoricalCrossentropy +from thinc.api import Config, Model from thinc.types import ArrayXd, Floats2d, Ints1d +from thinc.legacy import LegacySequenceCategoricalCrossentropy from ._edit_tree_internals.edit_trees import EditTrees from ._edit_tree_internals.schemas import validate_edit_tree @@ -129,7 +130,9 @@ class EditTreeLemmatizer(TrainablePipe): self, examples: Iterable[Example], scores: List[Floats2d] ) -> Tuple[float, List[Floats2d]]: validate_examples(examples, "EditTreeLemmatizer.get_loss") - loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1) + loss_func = LegacySequenceCategoricalCrossentropy( + normalize=False, missing_value=-1 + ) truths = [] for eg in examples: @@ -347,9 +350,9 @@ class EditTreeLemmatizer(TrainablePipe): tree = dict(tree) if "orig" in tree: - tree["orig"] = self.vocab.strings[tree["orig"]] + tree["orig"] = self.vocab.strings.add(tree["orig"]) if "orig" in tree: - tree["subst"] = self.vocab.strings[tree["subst"]] + tree["subst"] = self.vocab.strings.add(tree["subst"]) trees.append(tree) diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 782a1dabe..293add9e1 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -1,7 +1,8 @@ # cython: infer_types=True, profile=True, binding=True from typing import Callable, Dict, Iterable, List, Optional, Union import srsly -from thinc.api import SequenceCategoricalCrossentropy, Model, Config +from thinc.api import Model, Config +from thinc.legacy import LegacySequenceCategoricalCrossentropy from thinc.types import Floats2d, Ints1d from itertools import islice @@ -290,7 +291,7 @@ class Morphologizer(Tagger): DOCS: https://spacy.io/api/morphologizer#get_loss """ validate_examples(examples, "Morphologizer.get_loss") - loss_func = SequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False) + loss_func = LegacySequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False) truths = [] for eg in examples: eg_truths = [] diff --git a/spacy/pipeline/multitask.pyx b/spacy/pipeline/multitask.pyx deleted file mode 100644 index 8c44061e2..000000000 --- a/spacy/pipeline/multitask.pyx +++ /dev/null @@ -1,221 +0,0 @@ -# cython: infer_types=True, profile=True, binding=True -from typing import Optional -import numpy -from thinc.api import CosineDistance, to_categorical, Model, Config -from thinc.api import set_dropout_rate - -from ..tokens.doc cimport Doc - -from .trainable_pipe import TrainablePipe -from .tagger import Tagger -from ..training import validate_examples -from ..language import Language -from ._parser_internals import nonproj -from ..attrs import POS, ID -from ..errors import Errors - - -default_model_config = """ -[model] -@architectures = "spacy.MultiTask.v1" -maxout_pieces = 3 -token_vector_width = 96 - -[model.tok2vec] -@architectures = "spacy.HashEmbedCNN.v2" -pretrained_vectors 
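The `Beam._fill`/`Beam.advance` pair above amounts to a top-k expansion over (state, class) candidates: every surviving state is extended by every valid class, and only the `width` best extensions are kept. A much-simplified pure-Python illustration of that idea (ignoring state copying, hashing and `min_density` pruning; not the Cython implementation):

```python
import heapq


def beam_step(beam, scores, width):
    """One beam step. `beam` is a list of (score, history) pairs and scores[i][j]
    is the score gained by taking class j from beam state i."""
    candidates = []
    for i, (score, history) in enumerate(beam):
        for clas, delta in enumerate(scores[i]):
            candidates.append((score + delta, history + [clas]))
    # Keep the `width` best extended histories, as Beam.advance does after _fill.
    return heapq.nlargest(width, candidates)


beam = [(1.0, [0]), (0.5, [1])]
scores = [[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]]
print(beam_step(beam, scores, width=2))  # [(1.7, [0, 1]), (1.4, [1, 0])]
```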
= null -width = 96 -depth = 4 -embed_size = 2000 -window_size = 1 -maxout_pieces = 2 -subword_features = true -""" -DEFAULT_MT_MODEL = Config().from_str(default_model_config)["model"] - - -@Language.factory( - "nn_labeller", - default_config={"labels": None, "target": "dep_tag_offset", "model": DEFAULT_MT_MODEL} -) -def make_nn_labeller(nlp: Language, name: str, model: Model, labels: Optional[dict], target: str): - return MultitaskObjective(nlp.vocab, model, name) - - -class MultitaskObjective(Tagger): - """Experimental: Assist training of a parser or tagger, by training a - side-objective. - """ - - def __init__(self, vocab, model, name="nn_labeller", *, target): - self.vocab = vocab - self.model = model - self.name = name - if target == "dep": - self.make_label = self.make_dep - elif target == "tag": - self.make_label = self.make_tag - elif target == "ent": - self.make_label = self.make_ent - elif target == "dep_tag_offset": - self.make_label = self.make_dep_tag_offset - elif target == "ent_tag": - self.make_label = self.make_ent_tag - elif target == "sent_start": - self.make_label = self.make_sent_start - elif hasattr(target, "__call__"): - self.make_label = target - else: - raise ValueError(Errors.E016) - cfg = {"labels": {}, "target": target} - self.cfg = dict(cfg) - - @property - def labels(self): - return self.cfg.setdefault("labels", {}) - - @labels.setter - def labels(self, value): - self.cfg["labels"] = value - - def set_annotations(self, docs, dep_ids): - pass - - def initialize(self, get_examples, nlp=None, labels=None): - if not hasattr(get_examples, "__call__"): - err = Errors.E930.format(name="MultitaskObjective", obj=type(get_examples)) - raise ValueError(err) - if labels is not None: - self.labels = labels - else: - for example in get_examples(): - for token in example.y: - label = self.make_label(token) - if label is not None and label not in self.labels: - self.labels[label] = len(self.labels) - self.model.initialize() # TODO: fix initialization by defining X and Y - - def predict(self, docs): - tokvecs = self.model.get_ref("tok2vec")(docs) - scores = self.model.get_ref("softmax")(tokvecs) - return tokvecs, scores - - def get_loss(self, examples, scores): - cdef int idx = 0 - correct = numpy.zeros((scores.shape[0],), dtype="i") - guesses = scores.argmax(axis=1) - docs = [eg.predicted for eg in examples] - for i, eg in enumerate(examples): - # Handles alignment for tokenization differences - doc_annots = eg.get_aligned() # TODO - for j in range(len(eg.predicted)): - tok_annots = {key: values[j] for key, values in tok_annots.items()} - label = self.make_label(j, tok_annots) - if label is None or label not in self.labels: - correct[idx] = guesses[idx] - else: - correct[idx] = self.labels[label] - idx += 1 - correct = self.model.ops.xp.array(correct, dtype="i") - d_scores = scores - to_categorical(correct, n_classes=scores.shape[1]) - loss = (d_scores**2).sum() - return float(loss), d_scores - - @staticmethod - def make_dep(token): - return token.dep_ - - @staticmethod - def make_tag(token): - return token.tag_ - - @staticmethod - def make_ent(token): - if token.ent_iob_ == "O": - return "O" - else: - return token.ent_iob_ + "-" + token.ent_type_ - - @staticmethod - def make_dep_tag_offset(token): - dep = token.dep_ - tag = token.tag_ - offset = token.head.i - token.i - offset = min(offset, 2) - offset = max(offset, -2) - return f"{dep}-{tag}:{offset}" - - @staticmethod - def make_ent_tag(token): - if token.ent_iob_ == "O": - ent = "O" - else: - ent = token.ent_iob_ + "-" + 
token.ent_type_ - tag = token.tag_ - return f"{tag}-{ent}" - - @staticmethod - def make_sent_start(token): - """A multi-task objective for representing sentence boundaries, - using BILU scheme. (O is impossible) - """ - if token.is_sent_start and token.is_sent_end: - return "U-SENT" - elif token.is_sent_start: - return "B-SENT" - else: - return "I-SENT" - - -class ClozeMultitask(TrainablePipe): - def __init__(self, vocab, model, **cfg): - self.vocab = vocab - self.model = model - self.cfg = cfg - self.distance = CosineDistance(ignore_zeros=True, normalize=False) # TODO: in config - - def set_annotations(self, docs, dep_ids): - pass - - def initialize(self, get_examples, nlp=None): - self.model.initialize() # TODO: fix initialization by defining X and Y - X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO"))) - self.model.output_layer.initialize(X) - - def predict(self, docs): - tokvecs = self.model.get_ref("tok2vec")(docs) - vectors = self.model.get_ref("output_layer")(tokvecs) - return tokvecs, vectors - - def get_loss(self, examples, vectors, prediction): - validate_examples(examples, "ClozeMultitask.get_loss") - # The simplest way to implement this would be to vstack the - # token.vector values, but that's a bit inefficient, especially on GPU. - # Instead we fetch the index into the vectors table for each of our tokens, - # and look them up all at once. This prevents data copying. - ids = self.model.ops.flatten([eg.predicted.to_array(ID).ravel() for eg in examples]) - target = vectors[ids] - gradient = self.distance.get_grad(prediction, target) - loss = self.distance.get_loss(prediction, target) - return float(loss), gradient - - def update(self, examples, *, drop=0., sgd=None, losses=None): - pass - - def rehearse(self, examples, drop=0., sgd=None, losses=None): - if losses is not None and self.name not in losses: - losses[self.name] = 0. - set_dropout_rate(self.model, drop) - validate_examples(examples, "ClozeMultitask.rehearse") - docs = [eg.predicted for eg in examples] - predictions, bp_predictions = self.model.begin_update() - loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions) - bp_predictions(d_predictions) - if sgd is not None: - self.finish_update(sgd) - if losses is not None: - losses[self.name] += loss - return losses - - def add_label(self, label): - raise NotImplementedError diff --git a/spacy/pipeline/pipe.pyx b/spacy/pipeline/pipe.pyx index 8407acc45..c5650382b 100644 --- a/spacy/pipeline/pipe.pyx +++ b/spacy/pipeline/pipe.pyx @@ -19,13 +19,6 @@ cdef class Pipe: DOCS: https://spacy.io/api/pipe """ - @classmethod - def __init_subclass__(cls, **kwargs): - """Raise a warning if an inheriting class implements 'begin_training' - (from v2) instead of the new 'initialize' method (from v3)""" - if hasattr(cls, "begin_training"): - warnings.warn(Warnings.W088.format(name=cls.__name__)) - def __call__(self, Doc doc) -> Doc: """Apply the pipe to one document. The document is modified in place, and returned. 
This usually happens under the hood when the nlp object diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 93a7ee796..42feeb277 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -3,7 +3,9 @@ from typing import Dict, Iterable, Optional, Callable, List, Union from itertools import islice import srsly -from thinc.api import Model, SequenceCategoricalCrossentropy, Config +from thinc.api import Model, Config +from thinc.legacy import LegacySequenceCategoricalCrossentropy + from thinc.types import Floats2d, Ints1d from ..tokens.doc cimport Doc @@ -161,7 +163,7 @@ class SentenceRecognizer(Tagger): """ validate_examples(examples, "SentenceRecognizer.get_loss") labels = self.labels - loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False) + loss_func = LegacySequenceCategoricalCrossentropy(names=labels, normalize=False) truths = [] for eg in examples: eg_truth = [] diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 3b4715ce5..e12f116af 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -2,7 +2,8 @@ from typing import Callable, Dict, Iterable, List, Optional, Union import numpy import srsly -from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config +from thinc.api import Model, set_dropout_rate, Config +from thinc.legacy import LegacySequenceCategoricalCrossentropy from thinc.types import Floats2d, Ints1d import warnings from itertools import islice @@ -244,7 +245,7 @@ class Tagger(TrainablePipe): DOCS: https://spacy.io/api/tagger#rehearse """ - loss_func = SequenceCategoricalCrossentropy() + loss_func = LegacySequenceCategoricalCrossentropy() if losses is None: losses = {} losses.setdefault(self.name, 0.0) @@ -275,7 +276,7 @@ class Tagger(TrainablePipe): DOCS: https://spacy.io/api/tagger#get_loss """ validate_examples(examples, "Tagger.get_loss") - loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"]) + loss_func = LegacySequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"]) # Convert empty tag "" to missing value None so that both misaligned # tokens and tokens with missing annotation have the default missing # value None. diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index bdf933c10..d64be66f6 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -155,11 +155,8 @@ class MultiLabel_TextCategorizer(TextCategorizer): name (str): The component instance name, used to add entries to the losses during training. threshold (float): Cutoff to consider a prediction "positive". -<<<<<<< HEAD - save_activations (bool): save model activations in Doc when annotating. -======= scorer (Optional[Callable]): The scoring method. ->>>>>>> upstream/master + save_activations (bool): save model activations in Doc when annotating. 
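The tagger-style components above (edit-tree lemmatizer, morphologizer, senter, tagger) all swap `thinc.api.SequenceCategoricalCrossentropy` for `thinc.legacy.LegacySequenceCategoricalCrossentropy` while keeping the same constructor arguments. A minimal sketch of the pattern, assuming the legacy class preserves the old interface (gradients and a scalar loss from per-document guesses and truths):

```python
from thinc.legacy import LegacySequenceCategoricalCrossentropy

# Same keyword arguments as the old thinc.api class (assumption: they are unchanged).
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False, missing_value=-1)

# guesses/truths are lists of per-doc arrays, as in EditTreeLemmatizer.get_loss above:
# d_scores, loss = loss_func(guesses, truths)
```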
DOCS: https://spacy.io/api/textcategorizer#init """ diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 340334b1a..9d7b258c6 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -10,12 +10,12 @@ import random import srsly from thinc.api import get_ops, set_dropout_rate, CupyOps, NumpyOps -from thinc.extra.search cimport Beam import numpy.random import numpy import warnings from ._parser_internals.stateclass cimport StateClass +from ._parser_internals.search cimport Beam from ..ml.parser_model cimport alloc_activations, free_activations from ..ml.parser_model cimport predict_states, arg_max_if_valid from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py index 2be286a57..b9c4ef715 100644 --- a/spacy/tests/conftest.py +++ b/spacy/tests/conftest.py @@ -1,6 +1,10 @@ import pytest from spacy.util import get_lang_class +import functools from hypothesis import settings +import inspect +import importlib +import sys # Functionally disable deadline settings for tests # to prevent spurious test failures in CI builds. @@ -47,6 +51,33 @@ def pytest_runtest_setup(item): pytest.skip("not referencing any issues") +# Decorator for Cython-built tests +# https://shwina.github.io/cython-testing/ +def cytest(func): + """ + Wraps `func` in a plain Python function. + """ + + @functools.wraps(func) + def wrapped(*args, **kwargs): + bound = inspect.signature(func).bind(*args, **kwargs) + return func(*bound.args, **bound.kwargs) + + return wrapped + + +def register_cython_tests(cython_mod_name: str, test_mod_name: str): + """ + Registers all callables with name `test_*` in Cython module `cython_mod_name` + as attributes in module `test_mod_name`, making them discoverable by pytest. 
+ """ + cython_mod = importlib.import_module(cython_mod_name) + for name in dir(cython_mod): + item = getattr(cython_mod, name) + if callable(item) and name.startswith("test_"): + setattr(sys.modules[test_mod_name], name, item) + + # Fixtures for language tokenizers (languages sorted alphabetically) diff --git a/spacy/tests/doc/test_array.py b/spacy/tests/doc/test_array.py index c334cc6eb..1f2d7d999 100644 --- a/spacy/tests/doc/test_array.py +++ b/spacy/tests/doc/test_array.py @@ -123,14 +123,14 @@ def test_doc_from_array_heads_in_bounds(en_vocab): # head before start arr = doc.to_array(["HEAD"]) - arr[0] = -1 + arr[0] = numpy.int32(-1).astype(numpy.uint64) doc_from_array = Doc(en_vocab, words=words) with pytest.raises(ValueError): doc_from_array.from_array(["HEAD"], arr) # head after end arr = doc.to_array(["HEAD"]) - arr[0] = 5 + arr[0] = numpy.int32(5).astype(numpy.uint64) doc_from_array = Doc(en_vocab, words=words) with pytest.raises(ValueError): doc_from_array.from_array(["HEAD"], arr) diff --git a/spacy/tests/doc/test_span_group.py b/spacy/tests/doc/test_span_group.py index da3c24908..5e8bea127 100644 --- a/spacy/tests/doc/test_span_group.py +++ b/spacy/tests/doc/test_span_group.py @@ -1,7 +1,10 @@ +from typing import List + import pytest from random import Random from spacy.matcher import Matcher -from spacy.tokens import Span, SpanGroup +from spacy.tokens import Span, SpanGroup, Doc +from spacy.util import filter_spans @pytest.fixture @@ -242,3 +245,13 @@ def test_span_group_extend(doc): def test_span_group_dealloc(span_group): with pytest.raises(AttributeError): print(span_group.doc) + + +@pytest.mark.issue(11975) +def test_span_group_typing(doc: Doc): + """Tests whether typing of `SpanGroup` as `Iterable[Span]`-like object is accepted by mypy.""" + span_group: SpanGroup = doc.spans["SPANS"] + spans: List[Span] = list(span_group) + for i, span in enumerate(span_group): + assert span == span_group[i] == spans[i] + filter_spans(span_group) diff --git a/spacy/tests/doc/test_underscore.py b/spacy/tests/doc/test_underscore.py index b934221af..d23bb3162 100644 --- a/spacy/tests/doc/test_underscore.py +++ b/spacy/tests/doc/test_underscore.py @@ -3,6 +3,10 @@ from mock import Mock from spacy.tokens import Doc, Span, Token from spacy.tokens.underscore import Underscore +# Helper functions +def _get_tuple(s: Span): + return "._.", "span_extension", s.start_char, s.end_char, s.label, s.kb_id, s.id + @pytest.fixture(scope="function", autouse=True) def clean_underscore(): @@ -171,3 +175,118 @@ def test_underscore_docstring(en_vocab): doc = Doc(en_vocab, words=["hello", "world"]) assert test_method.__doc__ == "I am a docstring" assert doc._.test_docstrings.__doc__.rsplit(". ")[-1] == "I am a docstring" + + +def test_underscore_for_unique_span(en_tokenizer): + """Test that spans with the same boundaries but with different labels are uniquely identified (see #9706).""" + Doc.set_extension(name="doc_extension", default=None) + Span.set_extension(name="span_extension", default=None) + Token.set_extension(name="token_extension", default=None) + + # Initialize doc + text = "Hello, world!" 
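Backing up to the `test_doc_from_array_heads_in_bounds` change above: `Doc.to_array` returns a uint64 array, and recent NumPy versions refuse to assign a negative Python integer into an unsigned array. The explicit `int32` to `uint64` cast stores the two's-complement bit pattern instead, which `Doc.from_array` then interprets as a signed head offset. A small illustration (array values are for demonstration only):

```python
import numpy

arr = numpy.zeros((3,), dtype="uint64")
# arr[0] = -1 would raise OverflowError on NumPy >= 1.24; cast explicitly instead.
arr[0] = numpy.int32(-1).astype(numpy.uint64)
print(arr.dtype, arr[0])
```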
+ doc = en_tokenizer(text) + span_1 = Span(doc, 0, 2, "SPAN_1") + span_2 = Span(doc, 0, 2, "SPAN_2") + + # Set custom extensions + doc._.doc_extension = "doc extension" + doc[0]._.token_extension = "token extension" + span_1._.span_extension = "span_1 extension" + span_2._.span_extension = "span_2 extension" + + # Assert extensions + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "span_2 extension" + + # Change label of span and assert extensions + span_1.label_ = "NEW_LABEL" + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "span_2 extension" + + # Change KB_ID and assert extensions + span_1.kb_id_ = "KB_ID" + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "span_2 extension" + + # Change extensions and assert + span_2._.span_extension = "updated span_2 extension" + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension" + + # Change span ID and assert extensions + span_2.id = 2 + assert doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert doc.user_data[_get_tuple(span_2)] == "updated span_2 extension" + + # Assert extensions with original key + assert doc.user_data[("._.", "doc_extension", None, None)] == "doc extension" + assert doc.user_data[("._.", "token_extension", 0, None)] == "token extension" + + +def test_underscore_for_unique_span_from_docs(en_tokenizer): + """Test that spans in the user_data keep the same data structure when using Doc.from_docs""" + Span.set_extension(name="span_extension", default=None) + Token.set_extension(name="token_extension", default=None) + + # Initialize doc + text_1 = "Hello, world!" + doc_1 = en_tokenizer(text_1) + span_1a = Span(doc_1, 0, 2, "SPAN_1a") + span_1b = Span(doc_1, 0, 2, "SPAN_1b") + + text_2 = "This is a test." + doc_2 = en_tokenizer(text_2) + span_2a = Span(doc_2, 0, 3, "SPAN_2a") + + # Set custom extensions + doc_1[0]._.token_extension = "token_1" + doc_2[1]._.token_extension = "token_2" + span_1a._.span_extension = "span_1a extension" + span_1b._.span_extension = "span_1b extension" + span_2a._.span_extension = "span_2a extension" + + doc = Doc.from_docs([doc_1, doc_2]) + # Assert extensions + assert doc_1.user_data[_get_tuple(span_1a)] == "span_1a extension" + assert doc_1.user_data[_get_tuple(span_1b)] == "span_1b extension" + assert doc_2.user_data[_get_tuple(span_2a)] == "span_2a extension" + + # Check extensions on merged doc + assert doc.user_data[_get_tuple(span_1a)] == "span_1a extension" + assert doc.user_data[_get_tuple(span_1b)] == "span_1b extension" + assert ( + doc.user_data[ + ( + "._.", + "span_extension", + span_2a.start_char + len(doc_1.text) + 1, + span_2a.end_char + len(doc_1.text) + 1, + span_2a.label, + span_2a.kb_id, + span_2a.id, + ) + ] + == "span_2a extension" + ) + + +def test_underscore_for_unique_span_as_span(en_tokenizer): + """Test that spans in the user_data keep the same data structure when using Span.as_doc""" + Span.set_extension(name="span_extension", default=None) + + # Initialize doc + text = "Hello, world!" 
+ doc = en_tokenizer(text) + span_1 = Span(doc, 0, 2, "SPAN_1") + span_2 = Span(doc, 0, 2, "SPAN_2") + + # Set custom extensions + span_1._.span_extension = "span_1 extension" + span_2._.span_extension = "span_2 extension" + + span_doc = span_1.as_doc(copy_user_data=True) + + # Assert extensions + assert span_doc.user_data[_get_tuple(span_1)] == "span_1 extension" + assert span_doc.user_data[_get_tuple(span_2)] == "span_2 extension" diff --git a/spacy/tests/parser/_search.pyx b/spacy/tests/parser/_search.pyx new file mode 100644 index 000000000..23fc81644 --- /dev/null +++ b/spacy/tests/parser/_search.pyx @@ -0,0 +1,119 @@ +# cython: infer_types=True, binding=True +from spacy.pipeline._parser_internals.search cimport Beam, MaxViolation +from spacy.typedefs cimport class_t, weight_t +from cymem.cymem cimport Pool + +from ..conftest import cytest +import pytest + +cdef struct TestState: + int length + int x + Py_UNICODE* string + + +cdef int transition(void* dest, void* src, class_t clas, void* extra_args) except -1: + dest_state = dest + src_state = src + dest_state.length = src_state.length + dest_state.x = src_state.x + dest_state.x += clas + if extra_args != NULL: + dest_state.string = extra_args + else: + dest_state.string = src_state.string + + +cdef void* initialize(Pool mem, int n, void* extra_args) except NULL: + state = mem.alloc(1, sizeof(TestState)) + state.length = n + state.x = 1 + if extra_args == NULL: + state.string = u'default' + else: + state.string = extra_args + return state + + +cdef int destroy(Pool mem, void* state, void* extra_args) except -1: + state = state + mem.free(state) + +@cytest +@pytest.mark.parametrize("nr_class,beam_width", + [ + (2, 3), + (3, 6), + (4, 20), + ] +) +def test_init(nr_class, beam_width): + b = Beam(nr_class, beam_width) + assert b.size == 1 + assert b.width == beam_width + assert b.nr_class == nr_class + +@cytest +def test_init_violn(): + MaxViolation() + +@cytest +@pytest.mark.parametrize("nr_class,beam_width,length", + [ + (2, 3, 3), + (3, 6, 15), + (4, 20, 32), + ] +) +def test_initialize(nr_class, beam_width, length): + b = Beam(nr_class, beam_width) + b.initialize(initialize, destroy, length, NULL) + for i in range(b.width): + s = b.at(i) + assert s.length == length, s.length + assert s.string == 'default' + + +@cytest +@pytest.mark.parametrize("nr_class,beam_width,length,extra", + [ + (2, 3, 4, None), + (3, 6, 15, u"test beam 1"), + ] +) +def test_initialize_extra(nr_class, beam_width, length, extra): + b = Beam(nr_class, beam_width) + if extra is None: + b.initialize(initialize, destroy, length, NULL) + else: + b.initialize(initialize, destroy, length, extra) + for i in range(b.width): + s = b.at(i) + assert s.length == length + + +@cytest +@pytest.mark.parametrize("nr_class,beam_width,length", + [ + (3, 6, 15), + (4, 20, 32), + ] +) +def test_transition(nr_class, beam_width, length): + b = Beam(nr_class, beam_width) + b.initialize(initialize, destroy, length, NULL) + b.set_cell(0, 2, 30, True, 0) + b.set_cell(0, 1, 42, False, 0) + b.advance(transition, NULL, NULL) + assert b.size == 1, b.size + assert b.score == 30, b.score + s = b.at(0) + assert s.x == 3 + assert b._states[0].score == 30, b._states[0].score + b.set_cell(0, 1, 10, True, 0) + b.set_cell(0, 2, 20, True, 0) + b.advance(transition, NULL, NULL) + assert b._states[0].score == 50, b._states[0].score + assert b._states[1].score == 40 + s = b.at(0) + assert s.x == 5 diff --git a/spacy/tests/parser/test_search.py b/spacy/tests/parser/test_search.py new file mode 100644 index 
000000000..136c3a11b --- /dev/null +++ b/spacy/tests/parser/test_search.py @@ -0,0 +1,3 @@ +from ..conftest import register_cython_tests + +register_cython_tests("spacy.tests.parser._search", __name__) diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py index ad2e56729..5eeb55aa2 100644 --- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -62,10 +62,45 @@ def test_initialize_from_labels(): nlp2 = Language() lemmatizer2 = nlp2.add_pipe("trainable_lemmatizer") lemmatizer2.initialize( - get_examples=lambda: train_examples, + # We want to check that the strings in replacement nodes are + # added to the string store. Avoid that they get added through + # the examples. + get_examples=lambda: train_examples[:1], labels=lemmatizer.label_data, ) assert lemmatizer2.tree2label == {1: 0, 3: 1, 4: 2, 6: 3} + assert lemmatizer2.label_data == { + "trees": [ + {"orig": "S", "subst": "s"}, + { + "prefix_len": 1, + "suffix_len": 0, + "prefix_tree": 0, + "suffix_tree": 4294967295, + }, + {"orig": "s", "subst": ""}, + { + "prefix_len": 0, + "suffix_len": 1, + "prefix_tree": 4294967295, + "suffix_tree": 2, + }, + { + "prefix_len": 0, + "suffix_len": 0, + "prefix_tree": 4294967295, + "suffix_tree": 4294967295, + }, + {"orig": "E", "subst": "e"}, + { + "prefix_len": 1, + "suffix_len": 0, + "prefix_tree": 5, + "suffix_tree": 4294967295, + }, + ], + "labels": (1, 3, 4, 6), + } def test_no_data(): diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 4dd7bae16..9b9786f04 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -529,17 +529,6 @@ def test_pipe_label_data_no_labels(pipe): assert "labels" not in get_arg_names(initialize) -def test_warning_pipe_begin_training(): - with pytest.warns(UserWarning, match="begin_training"): - - class IncompatPipe(TrainablePipe): - def __init__(self): - ... - - def begin_training(*args, **kwargs): - ... - - def test_pipe_methods_initialize(): """Test that the [initialize] config reflects the components correctly.""" nlp = Language() diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 2e706458f..c6768a3fd 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -5,6 +5,7 @@ from typing import Tuple, List, Dict, Any import pkg_resources import time +import spacy import numpy import pytest import srsly @@ -32,6 +33,7 @@ from spacy.cli.package import _is_permitted_package_name from spacy.cli.project.remote_storage import RemoteStorage from spacy.cli.project.run import _check_requirements from spacy.cli.validate import get_model_pkgs +from spacy.cli.apply import apply from spacy.cli.find_threshold import find_threshold from spacy.lang.en import English from spacy.lang.nl import Dutch @@ -123,6 +125,25 @@ def test_issue7055(): assert "model" in filled_cfg["components"]["ner"] +@pytest.mark.issue(11235) +def test_issue11235(): + """ + Test that the cli handles interpolation in the directory names correctly when loading project config. 
+ """ + lang_var = "en" + variables = {"lang": lang_var} + commands = [{"name": "x", "script": ["hello ${vars.lang}"]}] + directories = ["cfg", "${vars.lang}_model"] + project = {"commands": commands, "vars": variables, "directories": directories} + with make_tempdir() as d: + srsly.write_yaml(d / "project.yml", project) + cfg = load_project_config(d) + # Check that the directories are interpolated and created correctly + assert os.path.exists(d / "cfg") + assert os.path.exists(d / f"{lang_var}_model") + assert cfg["commands"][0]["script"][0] == f"hello {lang_var}" + + def test_cli_info(): nlp = Dutch() nlp.add_pipe("textcat") @@ -866,6 +887,82 @@ def test_span_length_freq_dist_output_must_be_correct(): assert list(span_freqs.keys()) == [3, 1, 4, 5, 2] +def test_applycli_empty_dir(): + with make_tempdir() as data_path: + output = data_path / "test.spacy" + apply(data_path, output, "blank:en", "text", 1, 1) + + +def test_applycli_docbin(): + with make_tempdir() as data_path: + output = data_path / "testout.spacy" + nlp = spacy.blank("en") + doc = nlp("testing apply cli.") + # test empty DocBin case + docbin = DocBin() + docbin.to_disk(data_path / "testin.spacy") + apply(data_path, output, "blank:en", "text", 1, 1) + docbin.add(doc) + docbin.to_disk(data_path / "testin.spacy") + apply(data_path, output, "blank:en", "text", 1, 1) + + +def test_applycli_jsonl(): + with make_tempdir() as data_path: + output = data_path / "testout.spacy" + data = [{"field": "Testing apply cli.", "key": 234}] + data2 = [{"field": "234"}] + srsly.write_jsonl(data_path / "test.jsonl", data) + apply(data_path, output, "blank:en", "field", 1, 1) + srsly.write_jsonl(data_path / "test2.jsonl", data2) + apply(data_path, output, "blank:en", "field", 1, 1) + + +def test_applycli_txt(): + with make_tempdir() as data_path: + output = data_path / "testout.spacy" + with open(data_path / "test.foo", "w") as ftest: + ftest.write("Testing apply cli.") + apply(data_path, output, "blank:en", "text", 1, 1) + + +def test_applycli_mixed(): + with make_tempdir() as data_path: + output = data_path / "testout.spacy" + text = "Testing apply cli" + nlp = spacy.blank("en") + doc = nlp(text) + jsonl_data = [{"text": text}] + srsly.write_jsonl(data_path / "test.jsonl", jsonl_data) + docbin = DocBin() + docbin.add(doc) + docbin.to_disk(data_path / "testin.spacy") + with open(data_path / "test.txt", "w") as ftest: + ftest.write(text) + apply(data_path, output, "blank:en", "text", 1, 1) + # Check whether it worked + result = list(DocBin().from_disk(output).get_docs(nlp.vocab)) + assert len(result) == 3 + for doc in result: + assert doc.text == text + + +def test_applycli_user_data(): + Doc.set_extension("ext", default=0) + val = ("ext", 0) + with make_tempdir() as data_path: + output = data_path / "testout.spacy" + nlp = spacy.blank("en") + doc = nlp("testing apply cli.") + doc._.ext = val + docbin = DocBin(store_user_data=True) + docbin.add(doc) + docbin.to_disk(data_path / "testin.spacy") + apply(data_path, output, "blank:en", "", 1, 1) + result = list(DocBin().from_disk(output).get_docs(nlp.vocab)) + assert result[0]._.ext == val + + def test_local_remote_storage(): with make_tempdir() as d: filename = "a.txt" diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index bf3da0ce4..25af6ca6a 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -359,6 +359,7 @@ cdef class Doc: for annot in annotations: if annot: if annot is heads or annot is sent_starts or annot is ent_iobs: + annot = numpy.array(annot, 
dtype=numpy.int32).astype(numpy.uint64) for i in range(len(words)): if attrs.ndim == 1: attrs[i] = annot[i] @@ -1177,13 +1178,22 @@ cdef class Doc: if "user_data" not in exclude: for key, value in doc.user_data.items(): - if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": - data_type, name, start, end = key + if isinstance(key, tuple) and len(key) >= 4 and key[0] == "._.": + data_type = key[0] + name = key[1] + start = key[2] + end = key[3] if start is not None or end is not None: start += char_offset if end is not None: end += char_offset - concat_user_data[(data_type, name, start, end)] = copy.copy(value) + _label = key[4] + _kb_id = key[5] + _span_id = key[6] + concat_user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value) + else: + concat_user_data[(data_type, name, start, end)] = copy.copy(value) + else: warnings.warn(Warnings.W101.format(name=name)) else: @@ -1564,6 +1574,7 @@ cdef class Doc: for j, (attr, annot) in enumerate(token_annotations.items()): if attr is HEAD: + annot = numpy.array(annot, dtype=numpy.int32).astype(numpy.uint64) for i in range(len(words)): array[i, j] = annot[i] elif attr is MORPH: @@ -1627,7 +1638,11 @@ cdef class Doc: Span.set_extension(span_attr) for span_data in doc_json["underscore_span"][span_attr]: value = span_data["value"] - self.char_span(span_data["start"], span_data["end"])._.set(span_attr, value) + span = self.char_span(span_data["start"], span_data["end"]) + span.label = span_data["label"] + span.kb_id = span_data["kb_id"] + span.id = span_data["id"] + span._.set(span_attr, value) return self def to_json(self, underscore=None): @@ -1705,13 +1720,16 @@ cdef class Doc: if attr not in data["underscore_token"]: data["underscore_token"][attr] = [] data["underscore_token"][attr].append({"start": start, "value": value}) - # Span attribute - elif start is not None and end is not None: + # Else span attribute + elif end is not None: + _label = data_key[4] + _kb_id = data_key[5] + _span_id = data_key[6] if "underscore_span" not in data: data["underscore_span"] = {} if attr not in data["underscore_span"]: data["underscore_span"][attr] = [] - data["underscore_span"][attr].append({"start": start, "end": end, "value": value}) + data["underscore_span"][attr].append({"start": start, "end": end, "value": value, "label": _label, "kb_id": _kb_id, "id":_span_id}) for attr in underscore: if attr not in user_keys: diff --git a/spacy/tokens/span.pyi b/spacy/tokens/span.pyi index abda49361..5168f3b03 100644 --- a/spacy/tokens/span.pyi +++ b/spacy/tokens/span.pyi @@ -93,8 +93,8 @@ class Span: self, start_idx: int, end_idx: int, - label: int = ..., - kb_id: int = ..., + label: Union[int, str] = ..., + kb_id: Union[int, str] = ..., vector: Optional[Floats1d] = ..., ) -> Span: ... @property diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 5530dd127..b605434fd 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -218,11 +218,10 @@ cdef class Span: cdef SpanC* span_c = self.span_c() """Custom extension attributes registered via `set_extension`.""" return Underscore(Underscore.span_extensions, self, - start=span_c.start_char, end=span_c.end_char) + start=span_c.start_char, end=span_c.end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) def as_doc(self, *, bint copy_user_data=False, array_head=None, array=None): """Create a `Doc` object with a copy of the `Span`'s data. - copy_user_data (bool): Whether or not to copy the original doc's user data. 
array_head (tuple): `Doc` array attrs, can be passed in to speed up computation. array (ndarray): `Doc` as array, can be passed in to speed up computation. @@ -275,12 +274,22 @@ cdef class Span: char_offset = self.start_char for key, value in self.doc.user_data.items(): if isinstance(key, tuple) and len(key) == 4 and key[0] == "._.": - data_type, name, start, end = key + data_type = key[0] + name = key[1] + start = key[2] + end = key[3] if start is not None or end is not None: start -= char_offset + # Check if Span object if end is not None: end -= char_offset - user_data[(data_type, name, start, end)] = copy.copy(value) + _label = key[4] + _kb_id = key[5] + _span_id = key[6] + user_data[(data_type, name, start, end, _label, _kb_id, _span_id)] = copy.copy(value) + # Else Token object + else: + user_data[(data_type, name, start, end)] = copy.copy(value) else: user_data[key] = copy.copy(value) doc.user_data = user_data @@ -309,7 +318,7 @@ cdef class Span: for ancestor in ancestors: ancestor_i = ancestor.i - span_c.start if ancestor_i in range(length): - array[i, head_col] = ancestor_i - i + array[i, head_col] = numpy.int32(ancestor_i - i).astype(numpy.uint64) # if there is no appropriate ancestor, define a new artificial root value = array[i, head_col] @@ -317,7 +326,7 @@ cdef class Span: new_root = old_to_new_root.get(ancestor_i, None) if new_root is not None: # take the same artificial root as a previous token from the same sentence - array[i, head_col] = new_root - i + array[i, head_col] = numpy.int32(new_root - i).astype(numpy.uint64) else: # set this token as the new artificial root array[i, head_col] = 0 @@ -781,21 +790,36 @@ cdef class Span: return self.span_c().label def __set__(self, attr_t label): - self.span_c().label = label + if label != self.span_c().label : + old_label = self.span_c().label + self.span_c().label = label + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=old_label, kb_id=self.kb_id, span_id=self.id) + Underscore._replace_keys(old, new) property kb_id: def __get__(self): return self.span_c().kb_id def __set__(self, attr_t kb_id): - self.span_c().kb_id = kb_id + if kb_id != self.span_c().kb_id : + old_kb_id = self.span_c().kb_id + self.span_c().kb_id = kb_id + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=old_kb_id, span_id=self.id) + Underscore._replace_keys(old, new) property id: def __get__(self): return self.span_c().id def __set__(self, attr_t id): - self.span_c().id = id + if id != self.span_c().id : + old_id = self.span_c().id + self.span_c().id = id + new = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=self.id) + old = Underscore(Underscore.span_extensions, self, start=self.span_c().start_char, end=self.span_c().end_char, label=self.label, kb_id=self.kb_id, span_id=old_id) + Underscore._replace_keys(old, new) property ent_id: """Alias for the span's ID.""" diff --git a/spacy/tokens/span_group.pyi b/spacy/tokens/span_group.pyi index 21cd124ab..0b4aa83aa 100644 --- 
a/spacy/tokens/span_group.pyi +++ b/spacy/tokens/span_group.pyi @@ -18,6 +18,7 @@ class SpanGroup: def doc(self) -> Doc: ... @property def has_overlap(self) -> bool: ... + def __iter__(self): ... def __len__(self) -> int: ... def append(self, span: Span) -> None: ... def extend(self, spans: Iterable[Span]) -> None: ... diff --git a/spacy/tokens/span_group.pyx b/spacy/tokens/span_group.pyx index 7caa01ee7..7325c1fa7 100644 --- a/spacy/tokens/span_group.pyx +++ b/spacy/tokens/span_group.pyx @@ -159,6 +159,16 @@ cdef class SpanGroup: return self._concat(other) return NotImplemented + def __iter__(self): + """ + Iterate over the spans in this SpanGroup. + YIELDS (Span): A span in this SpanGroup. + + DOCS: https://spacy.io/api/spangroup#iter + """ + for i in range(self.c.size()): + yield self[i] + def append(self, Span span): """Add a span to the group. The span must refer to the same Doc object as the span group. diff --git a/spacy/tokens/underscore.py b/spacy/tokens/underscore.py index e9a4e1862..f2f357441 100644 --- a/spacy/tokens/underscore.py +++ b/spacy/tokens/underscore.py @@ -2,10 +2,10 @@ from typing import Dict, Any, List, Optional, Tuple, Union, TYPE_CHECKING import functools import copy from ..errors import Errors +from .span import Span if TYPE_CHECKING: from .doc import Doc - from .span import Span from .token import Token @@ -25,6 +25,9 @@ class Underscore: obj: Union["Doc", "Span", "Token"], start: Optional[int] = None, end: Optional[int] = None, + label: int = 0, + kb_id: int = 0, + span_id: int = 0, ): object.__setattr__(self, "_extensions", extensions) object.__setattr__(self, "_obj", obj) @@ -36,6 +39,10 @@ class Underscore: object.__setattr__(self, "_doc", obj.doc) object.__setattr__(self, "_start", start) object.__setattr__(self, "_end", end) + if type(obj) == Span: + object.__setattr__(self, "_label", label) + object.__setattr__(self, "_kb_id", kb_id) + object.__setattr__(self, "_span_id", span_id) def __dir__(self) -> List[str]: # Hack to enable autocomplete on custom extensions @@ -88,8 +95,39 @@ class Underscore: def has(self, name: str) -> bool: return name in self._extensions - def _get_key(self, name: str) -> Tuple[str, str, Optional[int], Optional[int]]: - return ("._.", name, self._start, self._end) + def _get_key( + self, name: str + ) -> Union[ + Tuple[str, str, Optional[int], Optional[int]], + Tuple[str, str, Optional[int], Optional[int], int, int, int], + ]: + if hasattr(self, "_label"): + return ( + "._.", + name, + self._start, + self._end, + self._label, + self._kb_id, + self._span_id, + ) + else: + return "._.", name, self._start, self._end + + @staticmethod + def _replace_keys(old_underscore: "Underscore", new_underscore: "Underscore"): + """ + This function is called by Span when its kb_id or label are re-assigned. 
+ It checks if any user_data is stored for this span and replaces the keys + """ + for name in old_underscore._extensions: + old_key = old_underscore._get_key(name) + old_doc = old_underscore._doc + new_key = new_underscore._get_key(name) + if old_key != new_key and old_key in old_doc.user_data: + old_underscore._doc.user_data[ + new_key + ] = old_underscore._doc.user_data.pop(old_key) @classmethod def get_state(cls) -> Tuple[Dict[Any, Any], Dict[Any, Any], Dict[Any, Any]]: diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index dfd337b9e..95b0f0de9 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -443,26 +443,27 @@ def _annot2array(vocab, tok_annot, doc_annot): if key not in IDS: raise ValueError(Errors.E974.format(obj="token", key=key)) elif key in ["ORTH", "SPACY"]: - pass + continue elif key == "HEAD": attrs.append(key) - values.append([h-i if h is not None else 0 for i, h in enumerate(value)]) + row = [h-i if h is not None else 0 for i, h in enumerate(value)] elif key == "DEP": attrs.append(key) - values.append([vocab.strings.add(h) if h is not None else MISSING_DEP for h in value]) + row = [vocab.strings.add(h) if h is not None else MISSING_DEP for h in value] elif key == "SENT_START": attrs.append(key) - values.append([to_ternary_int(v) for v in value]) + row = [to_ternary_int(v) for v in value] elif key == "MORPH": attrs.append(key) - values.append([vocab.morphology.add(v) for v in value]) + row = [vocab.morphology.add(v) for v in value] else: attrs.append(key) if not all(isinstance(v, str) for v in value): types = set([type(v) for v in value]) raise TypeError(Errors.E969.format(field=key, types=types)) from None - values.append([vocab.strings.add(v) for v in value]) - array = numpy.asarray(values, dtype="uint64") + row = [vocab.strings.add(v) for v in value] + values.append([numpy.array(v, dtype=numpy.int32).astype(numpy.uint64) if v < 0 else v for v in row]) + array = numpy.array(values, dtype=numpy.uint64) return attrs, array.T diff --git a/spacy/util.py b/spacy/util.py index 4bdde1ad1..d674fb9ce 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -40,7 +40,7 @@ except ImportError: from .symbols import ORTH from .compat import cupy, CudaStream, is_windows, importlib_metadata -from .errors import Errors, Warnings, OLD_MODEL_SHORTCUTS +from .errors import Errors, Warnings from . import about if TYPE_CHECKING: @@ -427,8 +427,6 @@ def load_model( return load_model_from_path(Path(name), **kwargs) # type: ignore[arg-type] elif hasattr(name, "exists"): # Path or Path-like to model data return load_model_from_path(name, **kwargs) # type: ignore[arg-type] - if name in OLD_MODEL_SHORTCUTS: - raise IOError(Errors.E941.format(name=name, full=OLD_MODEL_SHORTCUTS[name])) # type: ignore[index] raise IOError(Errors.E050.format(name=name)) diff --git a/website/UNIVERSE.md b/website/UNIVERSE.md index 770bbde13..c3e49ba43 100644 --- a/website/UNIVERSE.md +++ b/website/UNIVERSE.md @@ -51,7 +51,7 @@ markup is correct. 
"import spacy", "import package_name", "", - "nlp = spacy.load('en')", + "nlp = spacy.load('en_core_web_sm')", "nlp.add_pipe(package_name)" ], "code_language": "python", diff --git a/website/docs/api/cli.md b/website/docs/api/cli.md index 92a123241..275e37ee0 100644 --- a/website/docs/api/cli.md +++ b/website/docs/api/cli.md @@ -12,6 +12,7 @@ menu: - ['train', 'train'] - ['pretrain', 'pretrain'] - ['evaluate', 'evaluate'] + - ['apply', 'apply'] - ['find-threshold', 'find-threshold'] - ['assemble', 'assemble'] - ['package', 'package'] @@ -1162,6 +1163,37 @@ $ python -m spacy evaluate [model] [data_path] [--output] [--code] [--gold-prepr | `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ | | **CREATES** | Training results and optional metrics and visualizations. | +## apply {#apply new="3.5" tag="command"} + +Applies a trained pipeline to data and stores the resulting annotated documents +in a `DocBin`. The input can be a single file or a directory. The recognized +input formats are: + +1. `.spacy` +2. `.jsonl` containing a user specified `text_key` +3. Files with any other extension are assumed to be plain text files containing + a single document. + +When a directory is provided it is traversed recursively to collect all files. + +```cli +$ python -m spacy apply [model] [data-path] [output-file] [--code] [--text-key] [--force-overwrite] [--gpu-id] [--batch-size] [--n-process] +``` + +| Name | Description | +| ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `model` | Pipeline to apply to the data. Can be a package or a path to a data directory. ~~str (positional)~~ | +| `data_path` | Location of data to be evaluated in spaCy's [binary format](/api/data-formats#training), jsonl, or plain text. ~~Path (positional)~~ | +| `output-file`, `-o` | Output `DocBin` path. ~~str (positional)~~ | +| `--code`, `-c` 3 | Path to Python file with additional code to be imported. Allows [registering custom functions](/usage/training#custom-functions) for new architectures. ~~Optional[Path] \(option)~~ | +| `--text-key`, `-tk` | The key for `.jsonl` files to use to grab the texts from. Defaults to `text`. ~~Optional[str] \(option)~~ | +| `--force-overwrite`, `-F` | If the provided `output-file` already exists, then force `apply` to overwrite it. If this is `False` (default) then quits with a warning instead. ~~bool (flag)~~ | +| `--gpu-id`, `-g` | GPU to use, if any. Defaults to `-1` for CPU. ~~int (option)~~ | +| `--batch-size`, `-b` | Batch size to use for prediction. Defaults to `1`. ~~int (option)~~ | +| `--n-process`, `-n` | Number of processes to use for prediction. Defaults to `1`. ~~int (option)~~ | +| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ | +| **CREATES** | A `DocBin` with the annotations from the `model` for all the files found in `data-path`. | + ## find-threshold {#find-threshold new="3.5" tag="command"} Runs prediction trials for a trained model with varying tresholds to maximize diff --git a/website/docs/api/dependencyparser.md b/website/docs/api/dependencyparser.md index 27e315592..c30d39b57 100644 --- a/website/docs/api/dependencyparser.md +++ b/website/docs/api/dependencyparser.md @@ -169,12 +169,6 @@ arguments it receives via the [`[initialize.components]`](/api/data-formats#config-initialize) block in the config. 
- - -This method was previously called `begin_training`. - - - > #### Example > > ```python diff --git a/website/docs/api/entitylinker.md b/website/docs/api/entitylinker.md index 10132acd9..b116c4be4 100644 --- a/website/docs/api/entitylinker.md +++ b/website/docs/api/entitylinker.md @@ -200,12 +200,6 @@ knowledge base. This argument should be a function that takes a `Vocab` instance and creates the `KnowledgeBase`, ensuring that the strings of the knowledge base are synced with the current vocab. - - -This method was previously called `begin_training`. - - - > #### Example > > ```python diff --git a/website/docs/api/entityrecognizer.md b/website/docs/api/entityrecognizer.md index a535e8316..06828eb04 100644 --- a/website/docs/api/entityrecognizer.md +++ b/website/docs/api/entityrecognizer.md @@ -165,12 +165,6 @@ arguments it receives via the [`[initialize.components]`](/api/data-formats#config-initialize) block in the config. - - -This method was previously called `begin_training`. - - - > #### Example > > ```python diff --git a/website/docs/api/language.md b/website/docs/api/language.md index ad0ac2a46..4d568df62 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -259,15 +259,6 @@ either in the [config](/usage/training#config), or by calling [`pipe.add_label`](/api/pipe#add_label) for each possible output label (e.g. for the tagger or textcat). - - -This method was previously called `begin_training`. It now also takes a -**function** that is called with no arguments and returns a sequence of -[`Example`](/api/example) objects instead of tuples of `Doc` and `GoldParse` -objects. - - - > #### Example > > ```python diff --git a/website/docs/api/lexeme.md b/website/docs/api/lexeme.md index cd4086562..e13f25209 100644 --- a/website/docs/api/lexeme.md +++ b/website/docs/api/lexeme.md @@ -138,7 +138,7 @@ The L2 norm of the lexeme's vector representation. | `prefix` | Length-N substring from the start of the word. Defaults to `N=1`. ~~int~~ | | `prefix_` | Length-N substring from the start of the word. Defaults to `N=1`. ~~str~~ | | `suffix` | Length-N substring from the end of the word. Defaults to `N=3`. ~~int~~ | -| `suffix_` | Length-N substring from the start of the word. Defaults to `N=3`. ~~str~~ | +| `suffix_` | Length-N substring from the end of the word. Defaults to `N=3`. ~~str~~ | | `is_alpha` | Does the lexeme consist of alphabetic characters? Equivalent to `lexeme.text.isalpha()`. ~~bool~~ | | `is_ascii` | Does the lexeme consist of ASCII characters? Equivalent to `[any(ord(c) >= 128 for c in lexeme.text)]`. ~~bool~~ | | `is_digit` | Does the lexeme consist of digits? Equivalent to `lexeme.text.isdigit()`. ~~bool~~ | diff --git a/website/docs/api/pipe.md b/website/docs/api/pipe.md index 263942e3e..70a4648b6 100644 --- a/website/docs/api/pipe.md +++ b/website/docs/api/pipe.md @@ -152,12 +152,6 @@ network, setting up the label scheme based on the data. This method is typically called by [`Language.initialize`](/api/language#initialize). - - -This method was previously called `begin_training`. - - - > #### Example > > ```python diff --git a/website/docs/api/spangroup.md b/website/docs/api/spangroup.md index 2d1cf73c4..bd9659acb 100644 --- a/website/docs/api/spangroup.md +++ b/website/docs/api/spangroup.md @@ -202,6 +202,23 @@ already present in the current span group. | `other` | The span group or spans to append. ~~Union[SpanGroup, Iterable[Span]]~~ | | **RETURNS** | The span group. 
~~SpanGroup~~ | +## SpanGroup.\_\_iter\_\_ {#iter tag="method" new="3.5"} + +Iterate over the spans in this span group. + +> #### Example +> +> ```python +> doc = nlp("Their goi ng home") +> doc.spans["errors"] = [doc[0:1], doc[1:3]] +> for error_span in doc.spans["errors"]: +> print(error_span) +> ``` + +| Name | Description | +| ---------- | ----------------------------------- | +| **YIELDS** | A span in this span group. ~~Span~~ | + ## SpanGroup.append {#append tag="method"} Add a [`Span`](/api/span) object to the group. The span must refer to the same diff --git a/website/docs/api/tagger.md b/website/docs/api/tagger.md index 0d77d9bf4..102793377 100644 --- a/website/docs/api/tagger.md +++ b/website/docs/api/tagger.md @@ -142,12 +142,6 @@ arguments it receives via the [`[initialize.components]`](/api/data-formats#config-initialize) block in the config. - - -This method was previously called `begin_training`. - - - > #### Example > > ```python diff --git a/website/docs/api/textcategorizer.md b/website/docs/api/textcategorizer.md index ed1205d8c..b69c87a28 100644 --- a/website/docs/api/textcategorizer.md +++ b/website/docs/api/textcategorizer.md @@ -187,12 +187,6 @@ arguments it receives via the [`[initialize.components]`](/api/data-formats#config-initialize) block in the config. - - -This method was previously called `begin_training`. - - - > #### Example > > ```python diff --git a/website/docs/usage/models.md b/website/docs/usage/models.md index 3b1558bd8..03d0d535c 100644 --- a/website/docs/usage/models.md +++ b/website/docs/usage/models.md @@ -342,22 +342,6 @@ The easiest way to download a trained pipeline is via spaCy's [`download`](/api/cli#download) command. It takes care of finding the best-matching package compatible with your spaCy installation. -> #### Important note for v3.0 -> -> Note that as of spaCy v3.0, shortcut links like `en` that create (potentially -> brittle) symlinks in your spaCy installation are **deprecated**. To download -> and load an installed pipeline package, use its full name: -> -> ```diff -> - python -m spacy download en -> + python -m spacy download en_core_web_sm -> ``` -> -> ```diff -> - nlp = spacy.load("en") -> + nlp = spacy.load("en_core_web_sm") -> ``` - ```cli # Download best-matching version of a package for your spaCy installation $ python -m spacy download en_core_web_sm @@ -489,17 +473,6 @@ spacy.cli.download("en_core_web_sm") To load a pipeline package, use [`spacy.load`](/api/top-level#spacy.load) with the package name or a path to the data directory: -> #### Important note for v3.0 -> -> Note that as of spaCy v3.0, shortcut links like `en` that create (potentially -> brittle) symlinks in your spaCy installation are **deprecated**. 
To download -> and load an installed pipeline package, use its full name: -> -> ```diff -> - python -m spacy download en -> + python -m spacy download en_core_web_sm -> ``` - ```python import spacy nlp = spacy.load("en_core_web_sm") # load package "en_core_web_sm" diff --git a/website/meta/sidebars.json b/website/meta/sidebars.json index 2d8745d77..339e4085b 100644 --- a/website/meta/sidebars.json +++ b/website/meta/sidebars.json @@ -45,7 +45,7 @@ { "text": "v2.x Documentation", "url": "https://v2.spacy.io" }, { "text": "Custom Solutions", - "url": "https://explosion.ai/spacy-tailored-pipelines" + "url": "https://explosion.ai/custom-solutions" } ] } diff --git a/website/meta/site.json b/website/meta/site.json index 360a72178..fa79d3c69 100644 --- a/website/meta/site.json +++ b/website/meta/site.json @@ -51,7 +51,7 @@ { "text": "Online Course", "url": "https://course.spacy.io" }, { "text": "Custom Solutions", - "url": "https://explosion.ai/spacy-tailored-pipelines" + "url": "https://explosion.ai/custom-solutions" } ] }, diff --git a/website/meta/universe.json b/website/meta/universe.json index 97b53e9c5..84314328d 100644 --- a/website/meta/universe.json +++ b/website/meta/universe.json @@ -1021,31 +1021,13 @@ "author_links": { "github": "mholtzscher" }, - "category": ["pipeline"] - }, - { - "id": "spacy-sentence-segmenter", - "title": "Sentence Segmenter", - "slogan": "Custom sentence segmentation for spaCy", - "code_example": [ - "from seg.newline.segmenter import NewLineSegmenter", - "import spacy", - "", - "nlseg = NewLineSegmenter()", - "nlp = spacy.load('en')", - "nlp.add_pipe(nlseg.set_sent_starts, name='sentence_segmenter', before='parser')", - "doc = nlp(my_doc_text)" - ], - "author": "tc64", - "author_links": { - "github": "tc64" - }, - "category": ["pipeline"] + "category": ["pipeline"], + "spacy_version": 2 }, { "id": "spacy_cld", "title": "spaCy-CLD", - "slogan": "Add language detection to your spaCy pipeline using CLD2", + "slogan": "Add language detection to your spaCy v2 pipeline using CLD2", "description": "spaCy-CLD operates on `Doc` and `Span` spaCy objects. When called on a `Doc` or `Span`, the object is given two attributes: `languages` (a list of up to 3 language codes) and `language_scores` (a dictionary mapping language codes to confidence scores between 0 and 1).\n\nspacy-cld is a little extension that wraps the [PYCLD2](https://github.com/aboSamoor/pycld2) Python library, which in turn wraps the [Compact Language Detector 2](https://github.com/CLD2Owners/cld2) C library originally built at Google for the Chromium project. CLD2 uses character n-grams as features and a Naive Bayes classifier to identify 80+ languages from Unicode text strings (or XML/HTML). 
It can detect up to 3 different languages in a given document, and reports a confidence score (reported in with each language.", "github": "nickdavidhaynes/spacy-cld", "pip": "spacy_cld", @@ -1065,7 +1047,8 @@ "author_links": { "github": "nickdavidhaynes" }, - "category": ["pipeline"] + "category": ["pipeline"], + "spacy_version": 2 }, { "id": "spacy-iwnlp", @@ -1139,7 +1122,8 @@ "github": "sammous" }, "category": ["pipeline"], - "tags": ["pos", "lemmatizer", "french"] + "tags": ["pos", "lemmatizer", "french"], + "spacy_version": 2 }, { "id": "lemmy", @@ -1333,8 +1317,8 @@ }, { "id": "neuralcoref", - "slogan": "State-of-the-art coreference resolution based on neural nets and spaCy", - "description": "This coreference resolution module is based on the super fast [spaCy](https://spacy.io/) parser and uses the neural net scoring model described in [Deep Reinforcement Learning for Mention-Ranking Coreference Models](http://cs.stanford.edu/people/kevclark/resources/clark-manning-emnlp2016-deep.pdf) by Kevin Clark and Christopher D. Manning, EMNLP 2016. Since ✨Neuralcoref v2.0, you can train the coreference resolution system on your own dataset — e.g., another language than English! — **provided you have an annotated dataset**. Note that to use neuralcoref with spaCy > 2.1.0, you'll have to install neuralcoref from source.", + "slogan": "State-of-the-art coreference resolution based on neural nets and spaCy v2", + "description": "This coreference resolution module is based on the super fast spaCy parser and uses the neural net scoring model described in [Deep Reinforcement Learning for Mention-Ranking Coreference Models](http://cs.stanford.edu/people/kevclark/resources/clark-manning-emnlp2016-deep.pdf) by Kevin Clark and Christopher D. Manning, EMNLP 2016. Since ✨Neuralcoref v2.0, you can train the coreference resolution system on your own dataset — e.g., another language than English! — **provided you have an annotated dataset**. 
Note that to use neuralcoref with spaCy > 2.1.0, you'll have to install neuralcoref from source, and v3+ is not supported.", "github": "huggingface/neuralcoref", "thumb": "https://i.imgur.com/j6FO9O6.jpg", "code_example": [ @@ -1355,7 +1339,8 @@ "github": "huggingface" }, "category": ["standalone", "conversational", "models"], - "tags": ["coref"] + "tags": ["coref"], + "spacy_version": 2 }, { "id": "neuralcoref-vizualizer", @@ -1431,7 +1416,7 @@ "import spacy", "import explacy", "", - "nlp = spacy.load('en')", + "nlp = spacy.load('en_core_web_sm')", "explacy.print_parse_info(nlp, 'The salad was surprisingly tasty.')" ], "author": "Tyler Neylon", @@ -1468,13 +1453,26 @@ "image": "https://jasonkessler.github.io/2012conventions0.0.2.2.png", "code_example": [ "import spacy", - "import scattertext as st", "", - "nlp = spacy.load('en')", - "corpus = st.CorpusFromPandas(convention_df,", - " category_col='party',", - " text_col='text',", - " nlp=nlp).build()" + "from scattertext import SampleCorpora, produce_scattertext_explorer", + "from scattertext import produce_scattertext_html", + "from scattertext.CorpusFromPandas import CorpusFromPandas", + "", + "nlp = spacy.load('en_core_web_sm')", + "convention_df = SampleCorpora.ConventionData2012.get_data()", + "corpus = CorpusFromPandas(convention_df,", + " category_col='party',", + " text_col='text',", + " nlp=nlp).build()", + "", + "html = produce_scattertext_html(corpus,", + " category='democrat',", + " category_name='Democratic',", + " not_category_name='Republican',", + " minimum_term_frequency=5,", + " width_in_pixels=1000)", + "open('./simple.html', 'wb').write(html.encode('utf-8'))", + "print('Open ./simple.html in Chrome or Firefox.')" ], "author": "Jason Kessler", "author_links": { diff --git a/website/src/widgets/landing.js b/website/src/widgets/landing.js index b7ae35f6e..c3aaa8a22 100644 --- a/website/src/widgets/landing.js +++ b/website/src/widgets/landing.js @@ -105,13 +105,13 @@ const Landing = ({ data }) => { - + spaCy Tailored Pipelines
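Stepping back from the individual hunks, the `Underscore` and `Span` changes above mean that span-level extension data is now keyed on label, KB ID and span ID in addition to character offsets, and that reassigning those attributes migrates the stored entries. A small sketch of the resulting behaviour, mirroring `test_underscore_for_unique_span`; the extension name `note` is invented for illustration and this is not a verbatim excerpt from the test suite.

```python
# Sketch of the behaviour enabled by the Underscore/Span changes in this diff.
import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp("Hello, world!")

Span.set_extension("note", default=None)

# Two spans with identical boundaries but different labels...
span_a = Span(doc, 0, 2, "SPAN_A")
span_b = Span(doc, 0, 2, "SPAN_B")

# ...now keep separate extension values, because the user_data key
# includes label, kb_id and span id alongside the char offsets.
span_a._.note = "first"
span_b._.note = "second"
assert span_a._.note == "first"
assert span_b._.note == "second"

# Re-labelling a span migrates its stored value to the new key,
# so the data is not orphaned.
span_a.label_ = "RENAMED"
assert span_a._.note == "first"
```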