Mirror of https://github.com/explosion/spaCy.git (synced 2025-04-13 05:34:15 +03:00)

Commit 7ee846b992: Merge branch 'master' into spacy.io
.github/workflows/tests.yml (vendored, 12 changed lines)
@@ -45,6 +45,12 @@ jobs:
run: |
python -m pip install flake8==5.0.4
python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics
- name: cython-lint
run: |
python -m pip install cython-lint -c requirements.txt
# E501: line too log, W291: trailing whitespace, E266: too many leading '#' for block comment
cython-lint spacy --ignore E501,W291,E266

tests:
name: Test
needs: Validate

@@ -52,10 +58,8 @@ jobs:
fail-fast: true
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python_version: ["3.11"]
python_version: ["3.11", "3.12.0-rc.2"]
include:
- os: ubuntu-20.04
python_version: "3.6"
- os: windows-latest
python_version: "3.7"
- os: macos-latest

@@ -89,7 +93,7 @@ jobs:
- name: Run mypy
run: |
python -m mypy spacy
if: matrix.python_version != '3.6'
if: matrix.python_version != '3.7'

- name: Delete source directory and .egg-info
run: |

README.md (74 changed lines)
|
@ -6,23 +6,20 @@ spaCy is a library for **advanced Natural Language Processing** in Python and
|
|||
Cython. It's built on the very latest research, and was designed from day one to
|
||||
be used in real products.
|
||||
|
||||
spaCy comes with
|
||||
[pretrained pipelines](https://spacy.io/models) and
|
||||
currently supports tokenization and training for **70+ languages**. It features
|
||||
state-of-the-art speed and **neural network models** for tagging,
|
||||
parsing, **named entity recognition**, **text classification** and more,
|
||||
multi-task learning with pretrained **transformers** like BERT, as well as a
|
||||
spaCy comes with [pretrained pipelines](https://spacy.io/models) and currently
|
||||
supports tokenization and training for **70+ languages**. It features
|
||||
state-of-the-art speed and **neural network models** for tagging, parsing,
|
||||
**named entity recognition**, **text classification** and more, multi-task
|
||||
learning with pretrained **transformers** like BERT, as well as a
|
||||
production-ready [**training system**](https://spacy.io/usage/training) and easy
|
||||
model packaging, deployment and workflow management. spaCy is commercial
|
||||
open-source software, released under the [MIT license](https://github.com/explosion/spaCy/blob/master/LICENSE).
|
||||
open-source software, released under the
|
||||
[MIT license](https://github.com/explosion/spaCy/blob/master/LICENSE).
|
||||
|
||||
💥 **We'd love to hear more about your experience with spaCy!**
|
||||
[Fill out our survey here.](https://form.typeform.com/to/aMel9q9f)
|
||||
|
||||
💫 **Version 3.5 out now!**
|
||||
💫 **Version 3.7 out now!**
|
||||
[Check out the release notes here.](https://github.com/explosion/spaCy/releases)
|
||||
|
||||
[](https://dev.azure.com/explosion-ai/public/_build?definitionId=8)
|
||||
[](https://github.com/explosion/spaCy/actions/workflows/tests.yml)
|
||||
[](https://github.com/explosion/spaCy/releases)
|
||||
[](https://pypi.org/project/spacy/)
|
||||
[](https://anaconda.org/conda-forge/spacy)
|
||||
|
@ -35,22 +32,22 @@ open-source software, released under the [MIT license](https://github.com/explos
|
|||
|
||||
## 📖 Documentation
|
||||
|
||||
| Documentation | |
|
||||
| ----------------------------- | ---------------------------------------------------------------------- |
|
||||
| ⭐️ **[spaCy 101]** | New to spaCy? Here's everything you need to know! |
|
||||
| 📚 **[Usage Guides]** | How to use spaCy and its features. |
|
||||
| 🚀 **[New in v3.0]** | New features, backwards incompatibilities and migration guide. |
|
||||
| 🪐 **[Project Templates]** | End-to-end workflows you can clone, modify and run. |
|
||||
| 🎛 **[API Reference]** | The detailed reference for spaCy's API. |
|
||||
| 📦 **[Models]** | Download trained pipelines for spaCy. |
|
||||
| 🌌 **[Universe]** | Plugins, extensions, demos and books from the spaCy ecosystem. |
|
||||
| ⚙️ **[spaCy VS Code Extension]** | Additional tooling and features for working with spaCy's config files. |
|
||||
| 👩🏫 **[Online Course]** | Learn spaCy in this free and interactive online course. |
|
||||
| 📺 **[Videos]** | Our YouTube channel with video tutorials, talks and more. |
|
||||
| 🛠 **[Changelog]** | Changes and version history. |
|
||||
| 💝 **[Contribute]** | How to contribute to the spaCy project and code base. |
|
||||
| <a href="https://explosion.ai/spacy-tailored-pipelines"><img src="https://user-images.githubusercontent.com/13643239/152853098-1c761611-ccb0-4ec6-9066-b234552831fe.png" width="125" alt="spaCy Tailored Pipelines"/></a> | Get a custom spaCy pipeline, tailor-made for your NLP problem by spaCy's core developers. Streamlined, production-ready, predictable and maintainable. Start by completing our 5-minute questionnaire to tell us what you need and we'll be in touch! **[Learn more →](https://explosion.ai/spacy-tailored-pipelines)** |
|
||||
| <a href="https://explosion.ai/spacy-tailored-analysis"><img src="https://user-images.githubusercontent.com/1019791/206151300-b00cd189-e503-4797-aa1e-1bb6344062c5.png" width="125" alt="spaCy Tailored Pipelines"/></a> | Bespoke advice for problem solving, strategy and analysis for applied NLP projects. Services include data strategy, code reviews, pipeline design and annotation coaching. Curious? Fill in our 5-minute questionnaire to tell us what you need and we'll be in touch! **[Learn more →](https://explosion.ai/spacy-tailored-analysis)** |
|
||||
| Documentation | |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ⭐️ **[spaCy 101]** | New to spaCy? Here's everything you need to know! |
|
||||
| 📚 **[Usage Guides]** | How to use spaCy and its features. |
|
||||
| 🚀 **[New in v3.0]** | New features, backwards incompatibilities and migration guide. |
|
||||
| 🪐 **[Project Templates]** | End-to-end workflows you can clone, modify and run. |
|
||||
| 🎛 **[API Reference]** | The detailed reference for spaCy's API. |
|
||||
| 📦 **[Models]** | Download trained pipelines for spaCy. |
|
||||
| 🌌 **[Universe]** | Plugins, extensions, demos and books from the spaCy ecosystem. |
|
||||
| ⚙️ **[spaCy VS Code Extension]** | Additional tooling and features for working with spaCy's config files. |
|
||||
| 👩🏫 **[Online Course]** | Learn spaCy in this free and interactive online course. |
|
||||
| 📺 **[Videos]** | Our YouTube channel with video tutorials, talks and more. |
|
||||
| 🛠 **[Changelog]** | Changes and version history. |
|
||||
| 💝 **[Contribute]** | How to contribute to the spaCy project and code base. |
|
||||
| <a href="https://explosion.ai/spacy-tailored-pipelines"><img src="https://user-images.githubusercontent.com/13643239/152853098-1c761611-ccb0-4ec6-9066-b234552831fe.png" width="125" alt="spaCy Tailored Pipelines"/></a> | Get a custom spaCy pipeline, tailor-made for your NLP problem by spaCy's core developers. Streamlined, production-ready, predictable and maintainable. Start by completing our 5-minute questionnaire to tell us what you need and we'll be in touch! **[Learn more →](https://explosion.ai/spacy-tailored-pipelines)** |
|
||||
| <a href="https://explosion.ai/spacy-tailored-analysis"><img src="https://user-images.githubusercontent.com/1019791/206151300-b00cd189-e503-4797-aa1e-1bb6344062c5.png" width="125" alt="spaCy Tailored Pipelines"/></a> | Bespoke advice for problem solving, strategy and analysis for applied NLP projects. Services include data strategy, code reviews, pipeline design and annotation coaching. Curious? Fill in our 5-minute questionnaire to tell us what you need and we'll be in touch! **[Learn more →](https://explosion.ai/spacy-tailored-analysis)** |
|
||||
|
||||
[spacy 101]: https://spacy.io/usage/spacy-101
|
||||
[new in v3.0]: https://spacy.io/usage/v3
|
||||
|
@ -58,7 +55,7 @@ open-source software, released under the [MIT license](https://github.com/explos
|
|||
[api reference]: https://spacy.io/api/
|
||||
[models]: https://spacy.io/models
|
||||
[universe]: https://spacy.io/universe
|
||||
[spaCy VS Code Extension]: https://github.com/explosion/spacy-vscode
|
||||
[spacy vs code extension]: https://github.com/explosion/spacy-vscode
|
||||
[videos]: https://www.youtube.com/c/ExplosionAI
|
||||
[online course]: https://course.spacy.io
|
||||
[project templates]: https://github.com/explosion/projects
|
||||
|
@ -92,7 +89,9 @@ more people can benefit from it.
|
|||
- State-of-the-art speed
|
||||
- Production-ready **training system**
|
||||
- Linguistically-motivated **tokenization**
|
||||
- Components for named **entity recognition**, part-of-speech-tagging, dependency parsing, sentence segmentation, **text classification**, lemmatization, morphological analysis, entity linking and more
|
||||
- Components for named **entity recognition**, part-of-speech-tagging,
|
||||
dependency parsing, sentence segmentation, **text classification**,
|
||||
lemmatization, morphological analysis, entity linking and more
|
||||
- Easily extensible with **custom components** and attributes
|
||||
- Support for custom models in **PyTorch**, **TensorFlow** and other frameworks
|
||||
- Built in **visualizers** for syntax and NER
|
||||
|
@ -109,7 +108,7 @@ For detailed installation instructions, see the
|
|||
|
||||
- **Operating system**: macOS / OS X · Linux · Windows (Cygwin, MinGW, Visual
|
||||
Studio)
|
||||
- **Python version**: Python 3.6+ (only 64 bit)
|
||||
- **Python version**: Python 3.7+ (only 64 bit)
|
||||
- **Package managers**: [pip] · [conda] (via `conda-forge`)
|
||||
|
||||
[pip]: https://pypi.org/project/spacy/
|
||||
|
@ -118,8 +117,8 @@ For detailed installation instructions, see the
|
|||
### pip
|
||||
|
||||
Using pip, spaCy releases are available as source packages and binary wheels.
|
||||
Before you install spaCy and its dependencies, make sure that
|
||||
your `pip`, `setuptools` and `wheel` are up to date.
|
||||
Before you install spaCy and its dependencies, make sure that your `pip`,
|
||||
`setuptools` and `wheel` are up to date.
|
||||
|
||||
```bash
|
||||
pip install -U pip setuptools wheel
|
||||
|
@ -174,9 +173,9 @@ with the new version.
|
|||
|
||||
## 📦 Download model packages
|
||||
|
||||
Trained pipelines for spaCy can be installed as **Python packages**. This
|
||||
means that they're a component of your application, just like any other module.
|
||||
Models can be installed using spaCy's [`download`](https://spacy.io/api/cli#download)
|
||||
Trained pipelines for spaCy can be installed as **Python packages**. This means
|
||||
that they're a component of your application, just like any other module. Models
|
||||
can be installed using spaCy's [`download`](https://spacy.io/api/cli#download)
|
||||
command, or manually by pointing pip to a path or URL.
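
As an illustration, the same can be done from Python; this is a minimal sketch, and the pipeline name below is only an example:

```python
import spacy

# Download and install the pipeline package if it isn't present yet, then load it.
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
doc = nlp("spaCy was created by Explosion.")
print([(ent.text, ent.label_) for ent in doc.ents])
```
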
| Documentation | |
|
||||
|
@ -242,8 +241,7 @@ do that depends on your system.
|
|||
| **Mac** | Install a recent version of [XCode](https://developer.apple.com/xcode/), including the so-called "Command Line Tools". macOS and OS X ship with Python and git preinstalled. |
|
||||
| **Windows** | Install a version of the [Visual C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) or [Visual Studio Express](https://visualstudio.microsoft.com/vs/express/) that matches the version that was used to compile your Python interpreter. |
|
||||
|
||||
For more details
|
||||
and instructions, see the documentation on
|
||||
For more details and instructions, see the documentation on
|
||||
[compiling spaCy from source](https://spacy.io/usage#source) and the
|
||||
[quickstart widget](https://spacy.io/usage#section-quickstart) to get the right
|
||||
commands for your platform and Python version.
|
||||

@@ -1,9 +1,6 @@
# build version constraints for use with wheelwright + multibuild
numpy==1.15.0; python_version<='3.7' and platform_machine!='aarch64'
numpy==1.19.2; python_version<='3.7' and platform_machine=='aarch64'
# build version constraints for use with wheelwright
numpy==1.15.0; python_version=='3.7' and platform_machine!='aarch64'
numpy==1.19.2; python_version=='3.7' and platform_machine=='aarch64'
numpy==1.17.3; python_version=='3.8' and platform_machine!='aarch64'
numpy==1.19.2; python_version=='3.8' and platform_machine=='aarch64'
numpy==1.19.3; python_version=='3.9'
numpy==1.21.3; python_version=='3.10'
numpy==1.23.2; python_version=='3.11'
numpy; python_version>='3.12'
numpy>=1.25.0; python_version>='3.9'
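
The `; python_version ...` suffixes above are standard PEP 508 environment markers, which pip evaluates against the running interpreter before applying a constraint. A minimal sketch of how such a marker resolves, using the third-party `packaging` library (not part of this commit):

```python
from packaging.markers import Marker

# A constraint like "numpy>=1.25.0; python_version>='3.9'" only applies when the
# marker after the semicolon evaluates to True for the current environment.
marker = Marker("python_version >= '3.9' and platform_machine != 'aarch64'")
print(marker.evaluate())  # True on a 64-bit x86 interpreter running Python 3.9+
```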

@@ -1,14 +1,17 @@
# Listeners

1. [Overview](#1-overview)
2. [Initialization](#2-initialization)
- [A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [B. Shape inference](#2b-shape-inference)
3. [Internal communication](#3-internal-communication)
- [A. During prediction](#3a-during-prediction)
- [B. During training](#3b-during-training)
- [C. Frozen components](#3c-frozen-components)
4. [Replacing listener with standalone](#4-replacing-listener-with-standalone)
- [1. Overview](#1-overview)
- [2. Initialization](#2-initialization)
- [2A. Linking listeners to the embedding component](#2a-linking-listeners-to-the-embedding-component)
- [2B. Shape inference](#2b-shape-inference)
- [3. Internal communication](#3-internal-communication)
- [3A. During prediction](#3a-during-prediction)
- [3B. During training](#3b-during-training)
- [Training with multiple listeners](#training-with-multiple-listeners)
- [3C. Frozen components](#3c-frozen-components)
- [The Tok2Vec or Transformer is frozen](#the-tok2vec-or-transformer-is-frozen)
- [The upstream component is frozen](#the-upstream-component-is-frozen)
- [4. Replacing listener with standalone](#4-replacing-listener-with-standalone)

## 1. Overview

@@ -62,7 +65,7 @@ of this `find_listener()` method will specifically identify sublayers of a model

If it's a Transformer-based pipeline, a
[`transformer` component](https://github.com/explosion/spacy-transformers/blob/master/spacy_transformers/pipeline_component.py)
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
has a similar implementation but its `find_listener()` function will specifically look for `TransformerListener`
sublayers of downstream components.

### 2B. Shape inference

@@ -154,7 +157,7 @@ as a tagger or a parser. This used to be impossible before 3.1, but has become s
embedding component in the [`annotating_components`](https://spacy.io/usage/training#annotating-components)
list of the config. This works like any other "annotating component" because it relies on the `Doc` attributes.

However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
However, if the `Tok2Vec` or `Transformer` is frozen, and not present in `annotating_components`, and a related
listener isn't frozen, then a `W086` warning is shown and further training of the pipeline will likely end with `E954`.
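
A rough sketch of the corresponding training settings, expressed as config overrides; the component name `tok2vec` is only an example and the keys follow spaCy's documented `[training]` section:

```python
# Illustrative only: freeze the embedding component but keep it annotating, so
# that listeners of non-frozen components still receive its predictions.
overrides = {
    "training.frozen_components": ["tok2vec"],
    "training.annotating_components": ["tok2vec"],
}
# These could then be applied when loading a config, e.g.:
# config = spacy.util.load_config("config.cfg", overrides=overrides)
```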

#### The upstream component is frozen

@@ -216,5 +219,17 @@ new_model = tok2vec_model.attrs["replace_listener"](new_model)
```

The new config and model are then properly stored on the `nlp` object.
Note that this functionality (running the replacement for a transformer listener) was broken prior to
Note that this functionality (running the replacement for a transformer listener) was broken prior to
`spacy-transformers` 1.0.5.

In spaCy 3.7, `Language.replace_listeners` was updated to pass the following additional arguments to the `replace_listener` callback:
the listener to be replaced and the `tok2vec`/`transformer` pipe from which the new model was copied. To maintain backwards-compatibility,
the method only passes these extra arguments for callbacks that support them:

```
def replace_listener_pre_37(copied_tok2vec_model):
    ...

def replace_listener_post_37(copied_tok2vec_model, replaced_listener, tok2vec_pipe):
    ...
```
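
A minimal sketch of how such version-dependent callbacks can be dispatched, assuming signature inspection is used; this illustrates the idea and is not the literal implementation in `Language.replace_listeners`:

```python
import inspect


def call_replace_listener(callback, new_model, replaced_listener, tok2vec_pipe):
    # Pre-3.7 callbacks take only the copied model; newer callbacks also accept
    # the replaced listener and the upstream tok2vec/transformer pipe.
    n_params = len(inspect.signature(callback).parameters)
    if n_params >= 3:
        return callback(new_model, replaced_listener, tok2vec_pipe)
    return callback(new_model)
```

Inspecting the callback's parameter count keeps older user-provided callbacks working unchanged while newer ones can use the extra context.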

@@ -5,8 +5,9 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.1.8,<8.2.0",
"numpy>=1.15.0",
"thinc>=8.1.8,<8.3.0",
"numpy>=1.15.0; python_version < '3.9'",
"numpy>=1.25.0; python_version >= '3.9'",
]
build-backend = "setuptools.build_meta"

@@ -3,7 +3,7 @@ spacy-legacy>=3.0.11,<3.1.0
spacy-loggers>=1.0.0,<2.0.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.1.8,<8.2.0
thinc>=8.1.8,<8.3.0
ml_datasets>=0.2.0,<0.3.0
murmurhash>=0.28.0,<1.1.0
wasabi>=0.9.1,<1.2.0

@@ -12,11 +12,13 @@ catalogue>=2.0.6,<2.1.0
typer>=0.3.0,<0.10.0
pathy>=0.10.0
smart-open>=5.2.1,<7.0.0
weasel>=0.1.0,<0.4.0
# Third party dependencies
numpy>=1.15.0
numpy>=1.15.0; python_version < "3.9"
numpy>=1.19.0; python_version >= "3.9"
requests>=2.13.0,<3.0.0
tqdm>=4.38.0,<5.0.0
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0
jinja2
langcodes>=3.2.0,<4.0.0
# Official Python utilities

@@ -31,11 +33,11 @@ pytest-timeout>=1.3.0,<2.0.0
mock>=2.0.0,<3.0.0
flake8>=3.8.0,<6.0.0
hypothesis>=3.27.0,<7.0.0
mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7"
types-dataclasses>=0.1.3; python_version < "3.7"
mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
types-mock>=0.1.1
types-setuptools>=57.0.0
types-requests
types-setuptools>=57.0.0
black==22.3.0
cython-lint>=0.15.0
isort>=5.0,<6.0

setup.cfg (24 changed lines)
|
@ -17,7 +17,6 @@ classifiers =
|
|||
Operating System :: Microsoft :: Windows
|
||||
Programming Language :: Cython
|
||||
Programming Language :: Python :: 3
|
||||
Programming Language :: Python :: 3.6
|
||||
Programming Language :: Python :: 3.7
|
||||
Programming Language :: Python :: 3.8
|
||||
Programming Language :: Python :: 3.9
|
||||
|
@ -31,15 +30,18 @@ project_urls =
|
|||
[options]
|
||||
zip_safe = false
|
||||
include_package_data = true
|
||||
python_requires = >=3.6
|
||||
python_requires = >=3.7
|
||||
# NOTE: This section is superseded by pyproject.toml and will be removed in
|
||||
# spaCy v4
|
||||
setup_requires =
|
||||
cython>=0.25,<3.0
|
||||
numpy>=1.15.0
|
||||
numpy>=1.15.0; python_version < "3.9"
|
||||
numpy>=1.19.0; python_version >= "3.9"
|
||||
# We also need our Cython packages here to compile against
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
thinc>=8.1.8,<8.2.0
|
||||
thinc>=8.1.8,<8.3.0
|
||||
install_requires =
|
||||
# Our libraries
|
||||
spacy-legacy>=3.0.11,<3.1.0
|
||||
|
@ -47,18 +49,20 @@ install_requires =
|
|||
murmurhash>=0.28.0,<1.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.1.8,<8.2.0
|
||||
thinc>=8.1.8,<8.3.0
|
||||
wasabi>=0.9.1,<1.2.0
|
||||
srsly>=2.4.3,<3.0.0
|
||||
catalogue>=2.0.6,<2.1.0
|
||||
weasel>=0.1.0,<0.4.0
|
||||
# Third-party dependencies
|
||||
typer>=0.3.0,<0.10.0
|
||||
pathy>=0.10.0
|
||||
smart-open>=5.2.1,<7.0.0
|
||||
tqdm>=4.38.0,<5.0.0
|
||||
numpy>=1.15.0
|
||||
numpy>=1.15.0; python_version < "3.9"
|
||||
numpy>=1.19.0; python_version >= "3.9"
|
||||
requests>=2.13.0,<3.0.0
|
||||
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
|
||||
pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0
|
||||
jinja2
|
||||
# Official Python utilities
|
||||
setuptools
|
||||
|
@ -74,9 +78,7 @@ console_scripts =
|
|||
lookups =
|
||||
spacy_lookups_data>=1.0.3,<1.1.0
|
||||
transformers =
|
||||
spacy_transformers>=1.1.2,<1.3.0
|
||||
ray =
|
||||
spacy_ray>=0.1.0,<1.0.0
|
||||
spacy_transformers>=1.1.2,<1.4.0
|
||||
cuda =
|
||||
cupy>=5.0.0b4,<13.0.0
|
||||
cuda80 =
|
||||
|
@ -111,6 +113,8 @@ cuda117 =
|
|||
cupy-cuda117>=5.0.0b4,<13.0.0
|
||||
cuda11x =
|
||||
cupy-cuda11x>=11.0.0,<13.0.0
|
||||
cuda12x =
|
||||
cupy-cuda12x>=11.5.0,<13.0.0
|
||||
cuda-autodetect =
|
||||
cupy-wheel>=11.0.0,<13.0.0
|
||||
apple =
|
||||
|
|
setup.py (32 changed lines)
|
@ -1,10 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
from setuptools import Extension, setup, find_packages
|
||||
import sys
|
||||
import platform
|
||||
import numpy
|
||||
from distutils.command.build_ext import build_ext
|
||||
from distutils.sysconfig import get_python_inc
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from sysconfig import get_path
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from Cython.Build import cythonize
|
||||
|
@ -79,6 +78,7 @@ COMPILER_DIRECTIVES = {
|
|||
"language_level": -3,
|
||||
"embedsignature": True,
|
||||
"annotation_typing": False,
|
||||
"profile": sys.version_info < (3, 12),
|
||||
}
|
||||
# Files to copy into the package that are otherwise not included
|
||||
COPY_FILES = {
|
||||
|
@ -88,30 +88,6 @@ COPY_FILES = {
|
|||
}
|
||||
|
||||
|
||||
def is_new_osx():
|
||||
"""Check whether we're on OSX >= 10.7"""
|
||||
if sys.platform != "darwin":
|
||||
return False
|
||||
mac_ver = platform.mac_ver()[0]
|
||||
if mac_ver.startswith("10"):
|
||||
minor_version = int(mac_ver.split(".")[1])
|
||||
if minor_version >= 7:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
if is_new_osx():
|
||||
# On Mac, use libc++ because Apple deprecated use of
|
||||
# libstdc
|
||||
COMPILE_OPTIONS["other"].append("-stdlib=libc++")
|
||||
LINK_OPTIONS["other"].append("-lc++")
|
||||
# g++ (used by unix compiler on mac) links to libstdc++ as a default lib.
|
||||
# See: https://stackoverflow.com/questions/1653047/avoid-linking-to-libstdc
|
||||
LINK_OPTIONS["other"].append("-nodefaultlibs")
|
||||
|
||||
|
||||
# By subclassing build_extensions we have the actual compiler that will be used which is really known only after finalize_options
|
||||
# http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used
|
||||
class build_ext_options:
|
||||
|
@ -204,7 +180,7 @@ def setup_package():
|
|||
|
||||
include_dirs = [
|
||||
numpy.get_include(),
|
||||
get_python_inc(plat_specific=True),
|
||||
get_path("include"),
|
||||
]
|
||||
ext_modules = []
|
||||
ext_modules.append(
|
||||
|
|
|
@ -13,7 +13,6 @@ from thinc.api import Config, prefer_gpu, require_cpu, require_gpu # noqa: F401
|
|||
from . import pipeline # noqa: F401
|
||||
from . import util
|
||||
from .about import __version__ # noqa: F401
|
||||
from .cli.info import info # noqa: F401
|
||||
from .errors import Errors
|
||||
from .glossary import explain # noqa: F401
|
||||
from .language import Language
|
||||
|
@ -77,3 +76,9 @@ def blank(
|
|||
# We should accept both dot notation and nested dict here for consistency
|
||||
config = util.dot_to_dict(config)
|
||||
return LangClass.from_config(config, vocab=vocab, meta=meta)
|
||||
|
||||
|
||||
def info(*args, **kwargs):
|
||||
from .cli.info import info as cli_info
|
||||
|
||||
return cli_info(*args, **kwargs)
|
||||
|
|

@@ -1,7 +1,5 @@
# fmt: off
__title__ = "spacy"
__version__ = "3.6.0"
__version__ = "3.7.0"
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
__projects__ = "https://github.com/explosion/projects"
__projects_branch__ = "v3"
|
|
|
@ -96,4 +96,4 @@ cdef enum attr_id_t:
|
|||
ENT_ID = symbols.ENT_ID
|
||||
|
||||
IDX
|
||||
SENT_END
|
||||
SENT_END
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# cython: profile=False
|
||||
from .errors import Errors
|
||||
|
||||
IOB_STRINGS = ("", "I", "O", "B")
|
||||
|
@ -117,7 +118,7 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
|
|||
if "pos" in stringy_attrs:
|
||||
stringy_attrs["TAG"] = stringy_attrs.pop("pos")
|
||||
if "morph" in stringy_attrs:
|
||||
morphs = stringy_attrs.pop("morph")
|
||||
morphs = stringy_attrs.pop("morph") # no-cython-lint
|
||||
if "number" in stringy_attrs:
|
||||
stringy_attrs.pop("number")
|
||||
if "tenspect" in stringy_attrs:
|
||||
|
|
|
@ -14,6 +14,7 @@ from .debug_diff import debug_diff # noqa: F401
|
|||
from .debug_model import debug_model # noqa: F401
|
||||
from .download import download # noqa: F401
|
||||
from .evaluate import evaluate # noqa: F401
|
||||
from .find_function import find_function # noqa: F401
|
||||
from .find_threshold import find_threshold # noqa: F401
|
||||
from .info import info # noqa: F401
|
||||
from .init_config import fill_config, init_config # noqa: F401
|
||||
|
@ -21,13 +22,6 @@ from .init_pipeline import init_pipeline_cli # noqa: F401
|
|||
from .package import package # noqa: F401
|
||||
from .pretrain import pretrain # noqa: F401
|
||||
from .profile import profile # noqa: F401
|
||||
from .project.assets import project_assets # noqa: F401
|
||||
from .project.clone import project_clone # noqa: F401
|
||||
from .project.document import project_document # noqa: F401
|
||||
from .project.dvc import project_update_dvc # noqa: F401
|
||||
from .project.pull import project_pull # noqa: F401
|
||||
from .project.push import project_push # noqa: F401
|
||||
from .project.run import project_run # noqa: F401
|
||||
from .train import train_cli # noqa: F401
|
||||
from .validate import validate # noqa: F401
|
||||
|
||||
|
|
|
@ -25,10 +25,11 @@ from thinc.api import Config, ConfigValidationError, require_gpu
|
|||
from thinc.util import gpu_is_available
|
||||
from typer.main import get_command
|
||||
from wasabi import Printer, msg
|
||||
from weasel import app as project_cli
|
||||
|
||||
from .. import about
|
||||
from ..compat import Literal
|
||||
from ..schemas import ProjectConfigSchema, validate
|
||||
from ..schemas import validate
|
||||
from ..util import (
|
||||
ENV_VARS,
|
||||
SimpleFrozenDict,
|
||||
|
@ -48,7 +49,6 @@ SDIST_SUFFIX = ".tar.gz"
|
|||
WHEEL_SUFFIX = "-py3-none-any.whl"
|
||||
|
||||
PROJECT_FILE = "project.yml"
|
||||
PROJECT_LOCK = "project.lock"
|
||||
COMMAND = "python -m spacy"
|
||||
NAME = "spacy"
|
||||
HELP = """spaCy Command-line Interface
|
||||
|
@ -74,11 +74,10 @@ Opt = typer.Option
|
|||
|
||||
app = typer.Typer(name=NAME, help=HELP)
|
||||
benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_help=True)
|
||||
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
|
||||
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
|
||||
|
||||
app.add_typer(project_cli)
|
||||
app.add_typer(project_cli, name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||
app.add_typer(debug_cli)
|
||||
app.add_typer(benchmark_cli)
|
||||
app.add_typer(init_cli)
|
||||
|
@ -153,148 +152,6 @@ def _parse_override(value: Any) -> Any:
|
|||
return str(value)
|
||||
|
||||
|
||||
def load_project_config(
|
||||
path: Path, interpolate: bool = True, overrides: Dict[str, Any] = SimpleFrozenDict()
|
||||
) -> Dict[str, Any]:
|
||||
"""Load the project.yml file from a directory and validate it. Also make
|
||||
sure that all directories defined in the config exist.
|
||||
|
||||
path (Path): The path to the project directory.
|
||||
interpolate (bool): Whether to substitute project variables.
|
||||
overrides (Dict[str, Any]): Optional config overrides.
|
||||
RETURNS (Dict[str, Any]): The loaded project.yml.
|
||||
"""
|
||||
config_path = path / PROJECT_FILE
|
||||
if not config_path.exists():
|
||||
msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
|
||||
invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
|
||||
try:
|
||||
config = srsly.read_yaml(config_path)
|
||||
except ValueError as e:
|
||||
msg.fail(invalid_err, e, exits=1)
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
if errors:
|
||||
msg.fail(invalid_err)
|
||||
print("\n".join(errors))
|
||||
sys.exit(1)
|
||||
validate_project_version(config)
|
||||
validate_project_commands(config)
|
||||
if interpolate:
|
||||
err = f"{PROJECT_FILE} validation error"
|
||||
with show_validation_error(title=err, hint_fill=False):
|
||||
config = substitute_project_variables(config, overrides)
|
||||
# Make sure directories defined in config exist
|
||||
for subdir in config.get("directories", []):
|
||||
dir_path = path / subdir
|
||||
if not dir_path.exists():
|
||||
dir_path.mkdir(parents=True)
|
||||
return config
|
||||
|
||||
|
||||
def substitute_project_variables(
|
||||
config: Dict[str, Any],
|
||||
overrides: Dict[str, Any] = SimpleFrozenDict(),
|
||||
key: str = "vars",
|
||||
env_key: str = "env",
|
||||
) -> Dict[str, Any]:
|
||||
"""Interpolate variables in the project file using the config system.
|
||||
|
||||
config (Dict[str, Any]): The project config.
|
||||
overrides (Dict[str, Any]): Optional config overrides.
|
||||
key (str): Key containing variables in project config.
|
||||
env_key (str): Key containing environment variable mapping in project config.
|
||||
RETURNS (Dict[str, Any]): The interpolated project config.
|
||||
"""
|
||||
config.setdefault(key, {})
|
||||
config.setdefault(env_key, {})
|
||||
# Substitute references to env vars with their values
|
||||
for config_var, env_var in config[env_key].items():
|
||||
config[env_key][config_var] = _parse_override(os.environ.get(env_var, ""))
|
||||
# Need to put variables in the top scope again so we can have a top-level
|
||||
# section "project" (otherwise, a list of commands in the top scope wouldn't)
|
||||
# be allowed by Thinc's config system
|
||||
cfg = Config({"project": config, key: config[key], env_key: config[env_key]})
|
||||
cfg = Config().from_str(cfg.to_str(), overrides=overrides)
|
||||
interpolated = cfg.interpolate()
|
||||
return dict(interpolated["project"])
|
||||
|
||||
|
||||
def validate_project_version(config: Dict[str, Any]) -> None:
|
||||
"""If the project defines a compatible spaCy version range, chec that it's
|
||||
compatible with the current version of spaCy.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
spacy_version = config.get("spacy_version", None)
|
||||
if spacy_version and not is_compatible_version(about.__version__, spacy_version):
|
||||
err = (
|
||||
f"The {PROJECT_FILE} specifies a spaCy version range ({spacy_version}) "
|
||||
f"that's not compatible with the version of spaCy you're running "
|
||||
f"({about.__version__}). You can edit version requirement in the "
|
||||
f"{PROJECT_FILE} to load it, but the project may not run as expected."
|
||||
)
|
||||
msg.fail(err, exits=1)
|
||||
|
||||
|
||||
def validate_project_commands(config: Dict[str, Any]) -> None:
|
||||
"""Check that project commands and workflows are valid, don't contain
|
||||
duplicates, don't clash and only refer to commands that exist.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||
workflows = config.get("workflows", {})
|
||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||
if duplicates:
|
||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||
msg.fail(err, exits=1)
|
||||
for workflow_name, workflow_steps in workflows.items():
|
||||
if workflow_name in command_names:
|
||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||
msg.fail(err, exits=1)
|
||||
for step in workflow_steps:
|
||||
if step not in command_names:
|
||||
msg.fail(
|
||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
||||
f"Workflows can only refer to commands defined in the 'commands' "
|
||||
f"section of the {PROJECT_FILE}.",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
|
||||
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
|
||||
"""Get the hash for a JSON-serializable object.
|
||||
|
||||
data: The data to hash.
|
||||
exclude (Iterable[str]): Top-level keys to exclude if data is a dict.
|
||||
RETURNS (str): The hash.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
data = {k: v for k, v in data.items() if k not in exclude}
|
||||
data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
|
||||
return hashlib.md5(data_str).hexdigest()
|
||||
|
||||
|
||||
def get_checksum(path: Union[Path, str]) -> str:
|
||||
"""Get the checksum for a file or directory given its file path. If a
|
||||
directory path is provided, this uses all files in that directory.
|
||||
|
||||
path (Union[Path, str]): The file or directory path.
|
||||
RETURNS (str): The checksum.
|
||||
"""
|
||||
path = Path(path)
|
||||
if not (path.is_file() or path.is_dir()):
|
||||
msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
|
||||
if path.is_file():
|
||||
return hashlib.md5(Path(path).read_bytes()).hexdigest()
|
||||
else:
|
||||
# TODO: this is currently pretty slow
|
||||
dir_checksum = hashlib.md5()
|
||||
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
||||
dir_checksum.update(sub_file.read_bytes())
|
||||
return dir_checksum.hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def show_validation_error(
|
||||
file_path: Optional[Union[str, Path]] = None,
|
||||
|
@ -352,166 +209,10 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
|||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||
|
||||
|
||||
def upload_file(src: Path, dest: Union[str, "FluidPath"]) -> None:
|
||||
"""Upload a file.
|
||||
|
||||
src (Path): The source path.
|
||||
url (str): The destination URL to upload to.
|
||||
"""
|
||||
import smart_open
|
||||
|
||||
# Create parent directories for local paths
|
||||
if isinstance(dest, Path):
|
||||
if not dest.parent.exists():
|
||||
dest.parent.mkdir(parents=True)
|
||||
|
||||
dest = str(dest)
|
||||
with smart_open.open(dest, mode="wb") as output_file:
|
||||
with src.open(mode="rb") as input_file:
|
||||
output_file.write(input_file.read())
|
||||
|
||||
|
||||
def download_file(
|
||||
src: Union[str, "FluidPath"], dest: Path, *, force: bool = False
|
||||
) -> None:
|
||||
"""Download a file using smart_open.
|
||||
|
||||
url (str): The URL of the file.
|
||||
dest (Path): The destination path.
|
||||
force (bool): Whether to force download even if file exists.
|
||||
If False, the download will be skipped.
|
||||
"""
|
||||
import smart_open
|
||||
|
||||
if dest.exists() and not force:
|
||||
return None
|
||||
src = str(src)
|
||||
with smart_open.open(src, mode="rb", compression="disable") as input_file:
|
||||
with dest.open(mode="wb") as output_file:
|
||||
shutil.copyfileobj(input_file, output_file)
|
||||
|
||||
|
||||
def ensure_pathy(path):
|
||||
"""Temporary helper to prevent importing Pathy globally (which can cause
|
||||
slow and annoying Google Cloud warning)."""
|
||||
from pathy import Pathy # noqa: F811
|
||||
|
||||
return Pathy.fluid(path)
|
||||
|
||||
|
||||
def git_checkout(
|
||||
repo: str, subpath: str, dest: Path, *, branch: str = "master", sparse: bool = False
|
||||
):
|
||||
git_version = get_git_version()
|
||||
if dest.exists():
|
||||
msg.fail("Destination of checkout must not exist", exits=1)
|
||||
if not dest.parent.exists():
|
||||
msg.fail("Parent of destination of checkout must exist", exits=1)
|
||||
if sparse and git_version >= (2, 22):
|
||||
return git_sparse_checkout(repo, subpath, dest, branch)
|
||||
elif sparse:
|
||||
# Only show warnings if the user explicitly wants sparse checkout but
|
||||
# the Git version doesn't support it
|
||||
err_old = (
|
||||
f"You're running an old version of Git (v{git_version[0]}.{git_version[1]}) "
|
||||
f"that doesn't fully support sparse checkout yet."
|
||||
)
|
||||
err_unk = "You're running an unknown version of Git, so sparse checkout has been disabled."
|
||||
msg.warn(
|
||||
f"{err_unk if git_version == (0, 0) else err_old} "
|
||||
f"This means that more files than necessary may be downloaded "
|
||||
f"temporarily. To only download the files needed, make sure "
|
||||
f"you're using Git v2.22 or above."
|
||||
)
|
||||
with make_tempdir() as tmp_dir:
|
||||
cmd = f"git -C {tmp_dir} clone {repo} . -b {branch}"
|
||||
run_command(cmd, capture=True)
|
||||
# We need Path(name) to make sure we also support subdirectories
|
||||
try:
|
||||
source_path = tmp_dir / Path(subpath)
|
||||
if not is_subpath_of(tmp_dir, source_path):
|
||||
err = f"'{subpath}' is a path outside of the cloned repository."
|
||||
msg.fail(err, repo, exits=1)
|
||||
shutil.copytree(str(source_path), str(dest))
|
||||
except FileNotFoundError:
|
||||
err = f"Can't clone {subpath}. Make sure the directory exists in the repo (branch '{branch}')"
|
||||
msg.fail(err, repo, exits=1)
|
||||
|
||||
|
||||
def git_sparse_checkout(repo, subpath, dest, branch):
|
||||
# We're using Git, partial clone and sparse checkout to
|
||||
# only clone the files we need
|
||||
# This ends up being RIDICULOUS. omg.
|
||||
# So, every tutorial and SO post talks about 'sparse checkout'...But they
|
||||
# go and *clone* the whole repo. Worthless. And cloning part of a repo
|
||||
# turns out to be completely broken. The only way to specify a "path" is..
|
||||
# a path *on the server*? The contents of which, specifies the paths. Wat.
|
||||
# Obviously this is hopelessly broken and insecure, because you can query
|
||||
# arbitrary paths on the server! So nobody enables this.
|
||||
# What we have to do is disable *all* files. We could then just checkout
|
||||
# the path, and it'd "work", but be hopelessly slow...Because it goes and
|
||||
# transfers every missing object one-by-one. So the final piece is that we
|
||||
# need to use some weird git internals to fetch the missings in bulk, and
|
||||
# *that* we can do by path.
|
||||
# We're using Git and sparse checkout to only clone the files we need
|
||||
with make_tempdir() as tmp_dir:
|
||||
# This is the "clone, but don't download anything" part.
|
||||
cmd = (
|
||||
f"git clone {repo} {tmp_dir} --no-checkout --depth 1 "
|
||||
f"-b {branch} --filter=blob:none"
|
||||
)
|
||||
run_command(cmd)
|
||||
# Now we need to find the missing filenames for the subpath we want.
|
||||
# Looking for this 'rev-list' command in the git --help? Hah.
|
||||
cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
|
||||
ret = run_command(cmd, capture=True)
|
||||
git_repo = _http_to_git(repo)
|
||||
# Now pass those missings into another bit of git internals
|
||||
missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
|
||||
if not missings:
|
||||
err = (
|
||||
f"Could not find any relevant files for '{subpath}'. "
|
||||
f"Did you specify a correct and complete path within repo '{repo}' "
|
||||
f"and branch {branch}?"
|
||||
)
|
||||
msg.fail(err, exits=1)
|
||||
cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}"
|
||||
run_command(cmd, capture=True)
|
||||
# And finally, we can checkout our subpath
|
||||
cmd = f"git -C {tmp_dir} checkout {branch} {subpath}"
|
||||
run_command(cmd, capture=True)
|
||||
|
||||
# Get a subdirectory of the cloned path, if appropriate
|
||||
source_path = tmp_dir / Path(subpath)
|
||||
if not is_subpath_of(tmp_dir, source_path):
|
||||
err = f"'{subpath}' is a path outside of the cloned repository."
|
||||
msg.fail(err, repo, exits=1)
|
||||
|
||||
shutil.move(str(source_path), str(dest))
|
||||
|
||||
|
||||
def git_repo_branch_exists(repo: str, branch: str) -> bool:
|
||||
"""Uses 'git ls-remote' to check if a repository and branch exists
|
||||
|
||||
repo (str): URL to get repo.
|
||||
branch (str): Branch on repo to check.
|
||||
RETURNS (bool): True if repo:branch exists.
|
||||
"""
|
||||
get_git_version()
|
||||
cmd = f"git ls-remote {repo} {branch}"
|
||||
# We might be tempted to use `--exit-code` with `git ls-remote`, but
|
||||
# `run_command` handles the `returncode` for us, so we'll rely on
|
||||
# the fact that stdout returns '' if the requested branch doesn't exist
|
||||
ret = run_command(cmd, capture=True)
|
||||
exists = ret.stdout != ""
|
||||
return exists
|
||||
|
||||
|
||||
def get_git_version(
|
||||
error: str = "Could not run 'git'. Make sure it's installed and the executable is available.",
|
||||
) -> Tuple[int, int]:
|
||||
"""Get the version of git and raise an error if calling 'git --version' fails.
|
||||
|
||||
error (str): The error message to show.
|
||||
RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
|
||||
(0, 0) if the version couldn't be determined.
|
||||
|
@ -527,30 +228,6 @@ def get_git_version(
|
|||
return int(version[0]), int(version[1])
|
||||
|
||||
|
||||
def _http_to_git(repo: str) -> str:
|
||||
if repo.startswith("http://"):
|
||||
repo = repo.replace(r"http://", r"https://")
|
||||
if repo.startswith(r"https://"):
|
||||
repo = repo.replace("https://", "git@").replace("/", ":", 1)
|
||||
if repo.endswith("/"):
|
||||
repo = repo[:-1]
|
||||
repo = f"{repo}.git"
|
||||
return repo
|
||||
|
||||
|
||||
def is_subpath_of(parent, child):
|
||||
"""
|
||||
Check whether `child` is a path contained within `parent`.
|
||||
"""
|
||||
# Based on https://stackoverflow.com/a/37095733 .
|
||||
|
||||
# In Python 3.9, the `Path.is_relative_to()` method will supplant this, so
|
||||
# we can stop using crusty old os.path functions.
|
||||
parent_realpath = os.path.realpath(parent)
|
||||
child_realpath = os.path.realpath(child)
|
||||
return os.path.commonpath([parent_realpath, child_realpath]) == parent_realpath
|
||||
|
||||
|
||||
@overload
|
||||
def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
|
||||
...
|
||||
|
|
|
@ -133,7 +133,9 @@ def apply(
|
|||
if len(text_files) > 0:
|
||||
streams.append(_stream_texts(text_files))
|
||||
datagen = cast(DocOrStrStream, chain(*streams))
|
||||
for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
|
||||
for doc in tqdm.tqdm(
|
||||
nlp.pipe(datagen, batch_size=batch_size, n_process=n_process), disable=None
|
||||
):
|
||||
docbin.add(doc)
|
||||
if output_file.suffix == "":
|
||||
output_file = output_file.with_suffix(".spacy")
|
||||
|
|
|
@ -40,7 +40,8 @@ def assemble_cli(
|
|||
|
||||
DOCS: https://spacy.io/api/cli#assemble
|
||||
"""
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
if verbose:
|
||||
util.logger.setLevel(logging.DEBUG)
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or (str(config_path) != "-" and not config_path.exists()):
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
|
|
|
@ -89,7 +89,7 @@ class Quartiles:
|
|||
def annotate(
|
||||
nlp: Language, docs: List[Doc], batch_size: Optional[int]
|
||||
) -> numpy.ndarray:
|
||||
docs = nlp.pipe(tqdm(docs, unit="doc"), batch_size=batch_size)
|
||||
docs = nlp.pipe(tqdm(docs, unit="doc", disable=None), batch_size=batch_size)
|
||||
wps = []
|
||||
while True:
|
||||
with time_context() as elapsed:
|
||||
|
|
|
@ -28,6 +28,7 @@ def evaluate_cli(
|
|||
displacy_path: Optional[Path] = Opt(None, "--displacy-path", "-dp", help="Directory to output rendered parses as HTML", exists=True, file_okay=False),
|
||||
displacy_limit: int = Opt(25, "--displacy-limit", "-dl", help="Limit of parses to render as HTML"),
|
||||
per_component: bool = Opt(False, "--per-component", "-P", help="Return scores per component, only applicable when an output JSON file is specified."),
|
||||
spans_key: str = Opt("sc", "--spans-key", "-sk", help="Spans key to use when evaluating Doc.spans"),
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
|
@ -53,6 +54,7 @@ def evaluate_cli(
|
|||
displacy_limit=displacy_limit,
|
||||
per_component=per_component,
|
||||
silent=False,
|
||||
spans_key=spans_key,
|
||||
)
|
||||
|
||||
|
||||
|
|
spacy/cli/find_function.py (new file, 69 lines)
|
@ -0,0 +1,69 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from catalogue import RegistryError
|
||||
from wasabi import msg
|
||||
|
||||
from ..util import registry
|
||||
from ._util import Arg, Opt, app
|
||||
|
||||
|
||||
@app.command("find-function")
|
||||
def find_function_cli(
|
||||
# fmt: off
|
||||
func_name: str = Arg(..., help="Name of the registered function."),
|
||||
registry_name: Optional[str] = Opt(None, "--registry", "-r", help="Name of the catalogue registry."),
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Find the module, path and line number to the file the registered
|
||||
function is defined in, if available.
|
||||
|
||||
func_name (str): Name of the registered function.
|
||||
registry_name (Optional[str]): Name of the catalogue registry.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#find-function
|
||||
"""
|
||||
if not registry_name:
|
||||
registry_names = registry.get_registry_names()
|
||||
for name in registry_names:
|
||||
if registry.has(name, func_name):
|
||||
registry_name = name
|
||||
break
|
||||
|
||||
if not registry_name:
|
||||
msg.fail(
|
||||
f"Couldn't find registered function: '{func_name}'",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
assert registry_name is not None
|
||||
find_function(func_name, registry_name)
|
||||
|
||||
|
||||
def find_function(func_name: str, registry_name: str) -> Tuple[str, int]:
|
||||
registry_desc = None
|
||||
try:
|
||||
registry_desc = registry.find(registry_name, func_name)
|
||||
except RegistryError as e:
|
||||
msg.fail(
|
||||
f"Couldn't find registered function: '{func_name}' in registry '{registry_name}'",
|
||||
)
|
||||
msg.fail(f"{e}", exits=1)
|
||||
assert registry_desc is not None
|
||||
|
||||
registry_path = None
|
||||
line_no = None
|
||||
if registry_desc["file"]:
|
||||
registry_path = registry_desc["file"]
|
||||
line_no = registry_desc["line_no"]
|
||||
|
||||
if not registry_path or not line_no:
|
||||
msg.fail(
|
||||
f"Couldn't find path to registered function: '{func_name}' in registry '{registry_name}'",
|
||||
exits=1,
|
||||
)
|
||||
assert registry_path is not None
|
||||
assert line_no is not None
|
||||
|
||||
msg.good(f"Found registered function '{func_name}' at {registry_path}:{line_no}")
|
||||
return str(registry_path), int(line_no)
|
|
@ -52,8 +52,8 @@ def find_threshold_cli(
|
|||
|
||||
DOCS: https://spacy.io/api/cli#find-threshold
|
||||
"""
|
||||
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
if verbose:
|
||||
util.logger.setLevel(logging.DEBUG)
|
||||
import_code(code_path)
|
||||
find_threshold(
|
||||
model=model,
|
||||
|
|
|
@ -39,7 +39,8 @@ def init_vectors_cli(
|
|||
you can use in the [initialize] block of your config to initialize
|
||||
a model with vectors.
|
||||
"""
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
if verbose:
|
||||
util.logger.setLevel(logging.DEBUG)
|
||||
msg.info(f"Creating blank nlp object for language '{lang}'")
|
||||
nlp = util.get_lang_class(lang)()
|
||||
if jsonl_loc is not None:
|
||||
|
@ -87,7 +88,8 @@ def init_pipeline_cli(
|
|||
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
|
||||
# fmt: on
|
||||
):
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
if verbose:
|
||||
util.logger.setLevel(logging.DEBUG)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
setup_gpu(use_gpu)
|
||||
|
@ -116,7 +118,8 @@ def init_labels_cli(
|
|||
"""Generate JSON files for the labels in the data. This helps speed up the
|
||||
training process, since spaCy won't have to preprocess the data to
|
||||
extract the labels."""
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
if verbose:
|
||||
util.logger.setLevel(logging.DEBUG)
|
||||
if not output_path.exists():
|
||||
output_path.mkdir(parents=True)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
|
|
|
@ -403,7 +403,7 @@ def _format_sources(data: Any) -> str:
|
|||
if author:
|
||||
result += " ({})".format(author)
|
||||
sources.append(result)
|
||||
return "<br />".join(sources)
|
||||
return "<br>".join(sources)
|
||||
|
||||
|
||||
def _format_accuracy(data: Dict[str, Any], exclude: List[str] = ["speed"]) -> str:
|
||||
|
|
|
@ -71,7 +71,7 @@ def profile(model: str, inputs: Optional[Path] = None, n_texts: int = 10000) ->
|
|||
|
||||
|
||||
def parse_texts(nlp: Language, texts: Sequence[str]) -> None:
|
||||
for doc in nlp.pipe(tqdm.tqdm(texts), batch_size=16):
|
||||
for doc in nlp.pipe(tqdm.tqdm(texts, disable=None), batch_size=16):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -1,217 +0,0 @@
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
import typer
|
||||
from wasabi import msg
|
||||
|
||||
from ...util import ensure_path, working_dir
|
||||
from .._util import (
|
||||
PROJECT_FILE,
|
||||
Arg,
|
||||
Opt,
|
||||
SimpleFrozenDict,
|
||||
download_file,
|
||||
get_checksum,
|
||||
get_git_version,
|
||||
git_checkout,
|
||||
load_project_config,
|
||||
parse_config_overrides,
|
||||
project_cli,
|
||||
)
|
||||
|
||||
# Whether assets are extra if `extra` is not set.
|
||||
EXTRA_DEFAULT = False
|
||||
|
||||
|
||||
@project_cli.command(
|
||||
"assets",
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
def project_assets_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
project_dir: Path = Arg(Path.cwd(), help="Path to cloned project. Defaults to current working directory.", exists=True, file_okay=False),
|
||||
sparse_checkout: bool = Opt(False, "--sparse", "-S", help="Use sparse checkout for assets provided via Git, to only check out and clone the files needed. Requires Git v22.2+."),
|
||||
extra: bool = Opt(False, "--extra", "-e", help="Download all assets, including those marked as 'extra'.")
|
||||
# fmt: on
|
||||
):
|
||||
"""Fetch project assets like datasets and pretrained weights. Assets are
|
||||
defined in the "assets" section of the project.yml. If a checksum is
|
||||
provided in the project.yml, the file is only downloaded if no local file
|
||||
with the same checksum exists.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-assets
|
||||
"""
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
project_assets(
|
||||
project_dir,
|
||||
overrides=overrides,
|
||||
sparse_checkout=sparse_checkout,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def project_assets(
|
||||
project_dir: Path,
|
||||
*,
|
||||
overrides: Dict[str, Any] = SimpleFrozenDict(),
|
||||
sparse_checkout: bool = False,
|
||||
extra: bool = False,
|
||||
) -> None:
|
||||
"""Fetch assets for a project using DVC if possible.
|
||||
|
||||
project_dir (Path): Path to project directory.
|
||||
sparse_checkout (bool): Use sparse checkout for assets provided via Git, to only check out and clone the files
|
||||
needed.
|
||||
extra (bool): Whether to download all assets, including those marked as 'extra'.
|
||||
"""
|
||||
project_path = ensure_path(project_dir)
|
||||
config = load_project_config(project_path, overrides=overrides)
|
||||
assets = [
|
||||
asset
|
||||
for asset in config.get("assets", [])
|
||||
if extra or not asset.get("extra", EXTRA_DEFAULT)
|
||||
]
|
||||
if not assets:
|
||||
msg.warn(
|
||||
f"No assets specified in {PROJECT_FILE} (if assets are marked as extra, download them with --extra)",
|
||||
exits=0,
|
||||
)
|
||||
msg.info(f"Fetching {len(assets)} asset(s)")
|
||||
|
||||
for asset in assets:
|
||||
dest = (project_dir / asset["dest"]).resolve()
|
||||
checksum = asset.get("checksum")
|
||||
if "git" in asset:
|
||||
git_err = (
|
||||
f"Cloning spaCy project templates requires Git and the 'git' command. "
|
||||
f"Make sure it's installed and that the executable is available."
|
||||
)
|
||||
get_git_version(error=git_err)
|
||||
if dest.exists():
|
||||
# If there's already a file, check for checksum
|
||||
if checksum and checksum == get_checksum(dest):
|
||||
msg.good(
|
||||
f"Skipping download with matching checksum: {asset['dest']}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
if dest.is_dir():
|
||||
shutil.rmtree(dest)
|
||||
else:
|
||||
dest.unlink()
|
||||
if "repo" not in asset["git"] or asset["git"]["repo"] is None:
|
||||
msg.fail(
|
||||
"A git asset must include 'repo', the repository address.", exits=1
|
||||
)
|
||||
if "path" not in asset["git"] or asset["git"]["path"] is None:
|
||||
msg.fail(
|
||||
"A git asset must include 'path' - use \"\" to get the entire repository.",
|
||||
exits=1,
|
||||
)
|
||||
git_checkout(
|
||||
asset["git"]["repo"],
|
||||
asset["git"]["path"],
|
||||
dest,
|
||||
branch=asset["git"].get("branch"),
|
||||
sparse=sparse_checkout,
|
||||
)
|
||||
msg.good(f"Downloaded asset {dest}")
|
||||
else:
|
||||
url = asset.get("url")
|
||||
if not url:
|
||||
# project.yml defines asset without URL that the user has to place
|
||||
check_private_asset(dest, checksum)
|
||||
continue
|
||||
fetch_asset(project_path, url, dest, checksum)
|
||||
|
||||
|
||||
def check_private_asset(dest: Path, checksum: Optional[str] = None) -> None:
|
||||
"""Check and validate assets without a URL (private assets that the user
|
||||
has to provide themselves) and give feedback about the checksum.
|
||||
|
||||
dest (Path): Destination path of the asset.
|
||||
checksum (Optional[str]): Optional checksum of the expected file.
|
||||
"""
|
||||
if not Path(dest).exists():
|
||||
err = f"No URL provided for asset. You need to add this file yourself: {dest}"
|
||||
msg.warn(err)
|
||||
else:
|
||||
if not checksum:
|
||||
msg.good(f"Asset already exists: {dest}")
|
||||
elif checksum == get_checksum(dest):
|
||||
msg.good(f"Asset exists with matching checksum: {dest}")
|
||||
else:
|
||||
msg.fail(f"Asset available but with incorrect checksum: {dest}")
|
||||
|
||||
|
||||
def fetch_asset(
|
||||
project_path: Path, url: str, dest: Path, checksum: Optional[str] = None
|
||||
) -> None:
|
||||
"""Fetch an asset from a given URL or path. If a checksum is provided and a
|
||||
local file exists, it's only re-downloaded if the checksum doesn't match.
|
||||
|
||||
project_path (Path): Path to project directory.
|
||||
url (str): URL or path to asset.
|
||||
checksum (Optional[str]): Optional expected checksum of local file.
|
||||
RETURNS (Optional[Path]): The path to the fetched asset or None if fetching
|
||||
the asset failed.
|
||||
"""
|
||||
dest_path = (project_path / dest).resolve()
|
||||
if dest_path.exists():
|
||||
# If there's already a file, check for checksum
|
||||
if checksum:
|
||||
if checksum == get_checksum(dest_path):
|
||||
msg.good(f"Skipping download with matching checksum: {dest}")
|
||||
return
|
||||
else:
|
||||
# If there's not a checksum, make sure the file is a possibly valid size
|
||||
if os.path.getsize(dest_path) == 0:
|
||||
msg.warn(f"Asset exists but with size of 0 bytes, deleting: {dest}")
|
||||
os.remove(dest_path)
|
||||
# We might as well support the user here and create parent directories in
|
||||
# case the asset dir isn't listed as a dir to create in the project.yml
|
||||
if not dest_path.parent.exists():
|
||||
dest_path.parent.mkdir(parents=True)
|
||||
with working_dir(project_path):
|
||||
url = convert_asset_url(url)
|
||||
try:
|
||||
download_file(url, dest_path)
|
||||
msg.good(f"Downloaded asset {dest}")
|
||||
except requests.exceptions.RequestException as e:
|
||||
if Path(url).exists() and Path(url).is_file():
|
||||
# If it's a local file, copy to destination
|
||||
shutil.copy(url, str(dest_path))
|
||||
msg.good(f"Copied local asset {dest}")
|
||||
else:
|
||||
msg.fail(f"Download failed: {dest}", e)
|
||||
if checksum and checksum != get_checksum(dest_path):
|
||||
msg.fail(f"Checksum doesn't match value defined in {PROJECT_FILE}: {dest}")
|
||||
|
||||
|
||||
def convert_asset_url(url: str) -> str:
|
||||
"""Check and convert the asset URL if needed.
|
||||
|
||||
url (str): The asset URL.
|
||||
RETURNS (str): The converted URL.
|
||||
"""
|
||||
# If the asset URL is a regular GitHub URL it's likely a mistake
|
||||
if (
|
||||
re.match(r"(http(s?)):\/\/github.com", url)
|
||||
and "releases/download" not in url
|
||||
and "/raw/" not in url
|
||||
):
|
||||
converted = url.replace("github.com", "raw.githubusercontent.com")
|
||||
converted = re.sub(r"/(tree|blob)/", "/", converted)
|
||||
msg.warn(
|
||||
"Downloading from a regular GitHub URL. This will only download "
|
||||
"the source of the page, not the actual file. Converting the URL "
|
||||
"to a raw URL.",
|
||||
converted,
|
||||
)
|
||||
return converted
|
||||
return url
|
|
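# Illustrative sketch (hypothetical asset URL): what convert_asset_url() above does
# to a regular GitHub "blob" link before downloading.
import re

url = "https://github.com/explosion/projects/blob/v3/pipelines/data/assets.jsonl"
converted = url.replace("github.com", "raw.githubusercontent.com")
converted = re.sub(r"/(tree|blob)/", "/", converted)
# converted == "https://raw.githubusercontent.com/explosion/projects/v3/pipelines/data/assets.jsonl"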
@ -1,124 +0,0 @@
|
|||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
from ... import about
|
||||
from ...util import ensure_path
|
||||
from .._util import (
|
||||
COMMAND,
|
||||
PROJECT_FILE,
|
||||
Arg,
|
||||
Opt,
|
||||
get_git_version,
|
||||
git_checkout,
|
||||
git_repo_branch_exists,
|
||||
project_cli,
|
||||
)
|
||||
|
||||
DEFAULT_REPO = about.__projects__
|
||||
DEFAULT_PROJECTS_BRANCH = about.__projects_branch__
|
||||
DEFAULT_BRANCHES = ["main", "master"]
|
||||
|
||||
|
||||
@project_cli.command("clone")
|
||||
def project_clone_cli(
|
||||
# fmt: off
|
||||
name: str = Arg(..., help="The name of the template to clone"),
|
||||
dest: Optional[Path] = Arg(None, help="Where to clone the project. Defaults to current working directory", exists=False),
|
||||
repo: str = Opt(DEFAULT_REPO, "--repo", "-r", help="The repository to clone from"),
|
||||
branch: Optional[str] = Opt(None, "--branch", "-b", help=f"The branch to clone from. If not provided, will attempt {', '.join(DEFAULT_BRANCHES)}"),
|
||||
sparse_checkout: bool = Opt(False, "--sparse", "-S", help="Use sparse Git checkout to only check out and clone the files needed. Requires Git v2.22+.")
|
||||
# fmt: on
|
||||
):
|
||||
"""Clone a project template from a repository. Calls into "git" and will
|
||||
only download the files from the given subdirectory. The GitHub repo
|
||||
defaults to the official spaCy template repo, but can be customized
|
||||
(including using a private repo).
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-clone
|
||||
"""
|
||||
if dest is None:
|
||||
dest = Path.cwd() / Path(name).parts[-1]
|
||||
if repo == DEFAULT_REPO and branch is None:
|
||||
branch = DEFAULT_PROJECTS_BRANCH
|
||||
|
||||
if branch is None:
|
||||
for default_branch in DEFAULT_BRANCHES:
|
||||
if git_repo_branch_exists(repo, default_branch):
|
||||
branch = default_branch
|
||||
break
|
||||
if branch is None:
|
||||
default_branches_msg = ", ".join(f"'{b}'" for b in DEFAULT_BRANCHES)
|
||||
msg.fail(
|
||||
"No branch provided and attempted default "
|
||||
f"branches {default_branches_msg} do not exist.",
|
||||
exits=1,
|
||||
)
|
||||
else:
|
||||
if not git_repo_branch_exists(repo, branch):
|
||||
msg.fail(f"repo: {repo} (branch: {branch}) does not exist.", exits=1)
|
||||
assert isinstance(branch, str)
|
||||
project_clone(name, dest, repo=repo, branch=branch, sparse_checkout=sparse_checkout)
|
||||
|
||||
|
||||
def project_clone(
|
||||
name: str,
|
||||
dest: Path,
|
||||
*,
|
||||
repo: str = about.__projects__,
|
||||
branch: str = about.__projects_branch__,
|
||||
sparse_checkout: bool = False,
|
||||
) -> None:
|
||||
"""Clone a project template from a repository.
|
||||
|
||||
name (str): Name of subdirectory to clone.
|
||||
dest (Path): Destination path of cloned project.
|
||||
repo (str): URL of Git repo containing project templates.
|
||||
branch (str): The branch to clone from.
sparse_checkout (bool): Use sparse Git checkout to only check out and clone the files needed.
|
||||
"""
|
||||
dest = ensure_path(dest)
|
||||
check_clone(name, dest, repo)
|
||||
project_dir = dest.resolve()
|
||||
repo_name = re.sub(r"(http(s?)):\/\/github.com/", "", repo)
|
||||
try:
|
||||
git_checkout(repo, name, dest, branch=branch, sparse=sparse_checkout)
|
||||
except subprocess.CalledProcessError:
|
||||
err = f"Could not clone '{name}' from repo '{repo_name}' (branch '{branch}')"
|
||||
msg.fail(err, exits=1)
|
||||
msg.good(f"Cloned '{name}' from '{repo_name}' (branch '{branch}')", project_dir)
|
||||
if not (project_dir / PROJECT_FILE).exists():
|
||||
msg.warn(f"No {PROJECT_FILE} found in directory")
|
||||
else:
|
||||
msg.good(f"Your project is now ready!")
|
||||
print(f"To fetch the assets, run:\n{COMMAND} project assets {dest}")
|
||||
|
||||
|
||||
def check_clone(name: str, dest: Path, repo: str) -> None:
|
||||
"""Check and validate that the destination path can be used to clone. Will
|
||||
check that Git is available and that the destination path is suitable.
|
||||
|
||||
name (str): Name of the directory to clone from the repo.
|
||||
dest (Path): Local destination of cloned directory.
|
||||
repo (str): URL of the repo to clone from.
|
||||
"""
|
||||
git_err = (
|
||||
f"Cloning spaCy project templates requires Git and the 'git' command. "
|
||||
f"To clone a project without Git, copy the files from the '{name}' "
|
||||
f"directory in the {repo} to {dest} manually."
|
||||
)
|
||||
get_git_version(error=git_err)
|
||||
if not dest:
|
||||
msg.fail(f"Not a valid directory to clone project: {dest}", exits=1)
|
||||
if dest.exists():
|
||||
# Directory already exists (not allowed, clone needs to create it)
|
||||
msg.fail(f"Can't clone project, directory already exists: {dest}", exits=1)
|
||||
if not dest.parent.exists():
|
||||
# We're not creating parents, parent dir should exist
|
||||
msg.fail(
|
||||
f"Can't clone project, parent directory doesn't exist: {dest.parent}. "
|
||||
f"Create the necessary folder(s) first before continuing.",
|
||||
exits=1,
|
||||
)
|
|
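# Usage sketch (hypothetical template name and branch): calling project_clone()
# above directly, as the CLI wrapper does after resolving defaults.
from pathlib import Path

project_clone(
    "pipelines/tagger_parser_ud",              # hypothetical template subdirectory
    Path.cwd() / "tagger_parser_ud",
    repo=DEFAULT_REPO,
    branch="v3",                               # assumed projects branch
    sparse_checkout=False,
)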
@ -1,115 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
from wasabi import MarkdownRenderer, msg
|
||||
|
||||
from ...util import working_dir
|
||||
from .._util import PROJECT_FILE, Arg, Opt, load_project_config, project_cli
|
||||
|
||||
DOCS_URL = "https://spacy.io"
|
||||
INTRO_PROJECT = f"""The [`{PROJECT_FILE}`]({PROJECT_FILE}) defines the data assets required by the
|
||||
project, as well as the available commands and workflows. For details, see the
|
||||
[spaCy projects documentation]({DOCS_URL}/usage/projects)."""
|
||||
INTRO_COMMANDS = f"""The following commands are defined by the project. They
|
||||
can be executed using [`spacy project run [name]`]({DOCS_URL}/api/cli#project-run).
|
||||
Commands are only re-run if their inputs have changed."""
|
||||
INTRO_WORKFLOWS = f"""The following workflows are defined by the project. They
|
||||
can be executed using [`spacy project run [name]`]({DOCS_URL}/api/cli#project-run)
|
||||
and will run the specified commands in order. Commands are only re-run if their
|
||||
inputs have changed."""
|
||||
INTRO_ASSETS = f"""The following assets are defined by the project. They can
|
||||
be fetched by running [`spacy project assets`]({DOCS_URL}/api/cli#project-assets)
|
||||
in the project directory."""
|
||||
# These markers are added to the Markdown and can be used to update the file in
|
||||
# place if it already exists. Only the auto-generated part will be replaced.
|
||||
MARKER_START = "<!-- SPACY PROJECT: AUTO-GENERATED DOCS START (do not remove) -->"
|
||||
MARKER_END = "<!-- SPACY PROJECT: AUTO-GENERATED DOCS END (do not remove) -->"
|
||||
# If this marker is used in an existing README, it's ignored and not replaced
|
||||
MARKER_IGNORE = "<!-- SPACY PROJECT: IGNORE -->"
|
||||
|
||||
|
||||
@project_cli.command("document")
|
||||
def project_document_cli(
|
||||
# fmt: off
|
||||
project_dir: Path = Arg(Path.cwd(), help="Path to cloned project. Defaults to current working directory.", exists=True, file_okay=False),
|
||||
output_file: Path = Opt("-", "--output", "-o", help="Path to output Markdown file for output. Defaults to - for standard output"),
|
||||
no_emoji: bool = Opt(False, "--no-emoji", "-NE", help="Don't use emoji")
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
Auto-generate a README.md for a project. If the content is saved to a file,
|
||||
hidden markers are added so you can add custom content before or after the
|
||||
auto-generated section and only the auto-generated docs will be replaced
|
||||
when you re-run the command.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-document
|
||||
"""
|
||||
project_document(project_dir, output_file, no_emoji=no_emoji)
|
||||
|
||||
|
||||
def project_document(
|
||||
project_dir: Path, output_file: Path, *, no_emoji: bool = False
|
||||
) -> None:
|
||||
is_stdout = str(output_file) == "-"
|
||||
config = load_project_config(project_dir)
|
||||
md = MarkdownRenderer(no_emoji=no_emoji)
|
||||
md.add(MARKER_START)
|
||||
title = config.get("title")
|
||||
description = config.get("description")
|
||||
md.add(md.title(1, f"spaCy Project{f': {title}' if title else ''}", "🪐"))
|
||||
if description:
|
||||
md.add(description)
|
||||
md.add(md.title(2, PROJECT_FILE, "📋"))
|
||||
md.add(INTRO_PROJECT)
|
||||
# Commands
|
||||
cmds = config.get("commands", [])
|
||||
data = [(md.code(cmd["name"]), cmd.get("help", "")) for cmd in cmds]
|
||||
if data:
|
||||
md.add(md.title(3, "Commands", "⏯"))
|
||||
md.add(INTRO_COMMANDS)
|
||||
md.add(md.table(data, ["Command", "Description"]))
|
||||
# Workflows
|
||||
wfs = config.get("workflows", {}).items()
|
||||
data = [(md.code(n), " → ".join(md.code(w) for w in stp)) for n, stp in wfs]
|
||||
if data:
|
||||
md.add(md.title(3, "Workflows", "⏭"))
|
||||
md.add(INTRO_WORKFLOWS)
|
||||
md.add(md.table(data, ["Workflow", "Steps"]))
|
||||
# Assets
|
||||
assets = config.get("assets", [])
|
||||
data = []
|
||||
for a in assets:
|
||||
source = "Git" if a.get("git") else "URL" if a.get("url") else "Local"
|
||||
dest_path = a["dest"]
|
||||
dest = md.code(dest_path)
|
||||
if source == "Local":
|
||||
# Only link assets if they're in the repo
|
||||
with working_dir(project_dir) as p:
|
||||
if (p / dest_path).exists():
|
||||
dest = md.link(dest, dest_path)
|
||||
data.append((dest, source, a.get("description", "")))
|
||||
if data:
|
||||
md.add(md.title(3, "Assets", "🗂"))
|
||||
md.add(INTRO_ASSETS)
|
||||
md.add(md.table(data, ["File", "Source", "Description"]))
|
||||
md.add(MARKER_END)
|
||||
# Output result
|
||||
if is_stdout:
|
||||
print(md.text)
|
||||
else:
|
||||
content = md.text
|
||||
if output_file.exists():
|
||||
with output_file.open("r", encoding="utf8") as f:
|
||||
existing = f.read()
|
||||
if MARKER_IGNORE in existing:
|
||||
msg.warn("Found ignore marker in existing file: skipping", output_file)
|
||||
return
|
||||
if MARKER_START in existing and MARKER_END in existing:
|
||||
msg.info("Found existing file: only replacing auto-generated docs")
|
||||
before = existing.split(MARKER_START)[0]
|
||||
after = existing.split(MARKER_END)[1]
|
||||
content = f"{before}{content}{after}"
|
||||
else:
|
||||
msg.warn("Replacing existing file")
|
||||
with output_file.open("w", encoding="utf8") as f:
|
||||
f.write(content)
|
||||
msg.good("Saved project documentation", output_file)
|
|
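# Illustrative sketch: how the MARKER_START/MARKER_END markers above are used to
# splice regenerated docs into an existing README without touching custom content.
existing = f"My intro\n{MARKER_START}\nold docs\n{MARKER_END}\nMy outro\n"
generated = f"{MARKER_START}\nnew docs\n{MARKER_END}"
before = existing.split(MARKER_START)[0]    # "My intro\n"
after = existing.split(MARKER_END)[1]       # "\nMy outro\n"
updated = f"{before}{generated}{after}"     # intro + fresh auto-generated block + outro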
@ -1,220 +0,0 @@
|
|||
"""This module contains helpers and subcommands for integrating spaCy projects
|
||||
with Data Version Controk (DVC). https://dvc.org"""
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
from ...util import (
|
||||
SimpleFrozenList,
|
||||
join_command,
|
||||
run_command,
|
||||
split_command,
|
||||
working_dir,
|
||||
)
|
||||
from .._util import (
|
||||
COMMAND,
|
||||
NAME,
|
||||
PROJECT_FILE,
|
||||
Arg,
|
||||
Opt,
|
||||
get_hash,
|
||||
load_project_config,
|
||||
project_cli,
|
||||
)
|
||||
|
||||
DVC_CONFIG = "dvc.yaml"
|
||||
DVC_DIR = ".dvc"
|
||||
UPDATE_COMMAND = "dvc"
|
||||
DVC_CONFIG_COMMENT = f"""# This file is auto-generated by spaCy based on your {PROJECT_FILE}. If you've
|
||||
# edited your {PROJECT_FILE}, you can regenerate this file by running:
|
||||
# {COMMAND} project {UPDATE_COMMAND}"""
|
||||
|
||||
|
||||
@project_cli.command(UPDATE_COMMAND)
|
||||
def project_update_dvc_cli(
|
||||
# fmt: off
|
||||
project_dir: Path = Arg(Path.cwd(), help="Location of project directory. Defaults to current working directory.", exists=True, file_okay=False),
|
||||
workflow: Optional[str] = Arg(None, help=f"Name of workflow defined in {PROJECT_FILE}. Defaults to first workflow if not set."),
|
||||
verbose: bool = Opt(False, "--verbose", "-V", help="Print more info"),
|
||||
quiet: bool = Opt(False, "--quiet", "-q", help="Print less info"),
|
||||
force: bool = Opt(False, "--force", "-F", help="Force update DVC config"),
|
||||
# fmt: on
|
||||
):
|
||||
"""Auto-generate Data Version Control (DVC) config. A DVC
|
||||
project can only define one pipeline, so you need to specify one workflow
|
||||
defined in the project.yml. If no workflow is specified, the first defined
|
||||
workflow is used. The DVC config will only be updated if the project.yml
|
||||
changed.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-dvc
|
||||
"""
|
||||
project_update_dvc(project_dir, workflow, verbose=verbose, quiet=quiet, force=force)
|
||||
|
||||
|
||||
def project_update_dvc(
|
||||
project_dir: Path,
|
||||
workflow: Optional[str] = None,
|
||||
*,
|
||||
verbose: bool = False,
|
||||
quiet: bool = False,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Update the auto-generated Data Version Control (DVC) config file. A DVC
|
||||
project can only define one pipeline, so you need to specify one workflow
|
||||
defined in the project.yml. Will only update the file if the checksum changed.
|
||||
|
||||
project_dir (Path): The project directory.
|
||||
workflow (Optional[str]): Optional name of workflow defined in project.yml.
|
||||
If not set, the first workflow will be used.
|
||||
verbose (bool): Print more info.
|
||||
quiet (bool): Print less info.
|
||||
force (bool): Force update DVC config.
|
||||
"""
|
||||
config = load_project_config(project_dir)
|
||||
updated = update_dvc_config(
|
||||
project_dir, config, workflow, verbose=verbose, quiet=quiet, force=force
|
||||
)
|
||||
help_msg = "To execute the workflow with DVC, run: dvc repro"
|
||||
if updated:
|
||||
msg.good(f"Updated DVC config from {PROJECT_FILE}", help_msg)
|
||||
else:
|
||||
msg.info(f"No changes found in {PROJECT_FILE}, no update needed", help_msg)
|
||||
|
||||
|
||||
def update_dvc_config(
|
||||
path: Path,
|
||||
config: Dict[str, Any],
|
||||
workflow: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
quiet: bool = False,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
"""Re-run the DVC commands in dry mode and update dvc.yaml file in the
|
||||
project directory. The file is auto-generated based on the config. The
|
||||
first line of the auto-generated file specifies the hash of the config
|
||||
dict, so if any of the config values change, the DVC config is regenerated.
|
||||
|
||||
path (Path): The path to the project directory.
|
||||
config (Dict[str, Any]): The loaded project.yml.
|
||||
verbose (bool): Whether to print additional info (via DVC).
|
||||
quiet (bool): Don't output anything (via DVC).
|
||||
force (bool): Force update, even if hashes match.
|
||||
RETURNS (bool): Whether the DVC config file was updated.
|
||||
"""
|
||||
ensure_dvc(path)
|
||||
workflows = config.get("workflows", {})
|
||||
workflow_names = list(workflows.keys())
|
||||
check_workflows(workflow_names, workflow)
|
||||
if not workflow:
|
||||
workflow = workflow_names[0]
|
||||
config_hash = get_hash(config)
|
||||
path = path.resolve()
|
||||
dvc_config_path = path / DVC_CONFIG
|
||||
if dvc_config_path.exists():
|
||||
# Check if the file was generated using the current config, if not, redo
|
||||
with dvc_config_path.open("r", encoding="utf8") as f:
|
||||
ref_hash = f.readline().strip().replace("# ", "")
|
||||
if ref_hash == config_hash and not force:
|
||||
return False # Nothing has changed in project.yml, don't need to update
|
||||
dvc_config_path.unlink()
|
||||
dvc_commands = []
|
||||
config_commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
|
||||
|
||||
# some flags that apply to every command
|
||||
flags = []
|
||||
if verbose:
|
||||
flags.append("--verbose")
|
||||
if quiet:
|
||||
flags.append("--quiet")
|
||||
|
||||
for name in workflows[workflow]:
|
||||
command = config_commands[name]
|
||||
deps = command.get("deps", [])
|
||||
outputs = command.get("outputs", [])
|
||||
outputs_no_cache = command.get("outputs_no_cache", [])
|
||||
if not deps and not outputs and not outputs_no_cache:
|
||||
continue
|
||||
# Default to the working dir as the project path since dvc.yaml is auto-generated
|
||||
# and we don't want arbitrary paths in there
|
||||
project_cmd = ["python", "-m", NAME, "project", "run", name]
|
||||
deps_cmd = [c for cl in [["-d", p] for p in deps] for c in cl]
|
||||
outputs_cmd = [c for cl in [["-o", p] for p in outputs] for c in cl]
|
||||
outputs_nc_cmd = [c for cl in [["-O", p] for p in outputs_no_cache] for c in cl]
|
||||
|
||||
dvc_cmd = ["run", *flags, "-n", name, "-w", str(path), "--no-exec"]
|
||||
if command.get("no_skip"):
|
||||
dvc_cmd.append("--always-changed")
|
||||
full_cmd = [*dvc_cmd, *deps_cmd, *outputs_cmd, *outputs_nc_cmd, *project_cmd]
|
||||
dvc_commands.append(join_command(full_cmd))
|
||||
|
||||
if not dvc_commands:
|
||||
# If we don't check for this, then there will be an error when reading the
|
||||
# config, since DVC wouldn't create it.
|
||||
msg.fail(
|
||||
"No usable commands for DVC found. This can happen if none of your "
|
||||
"commands have dependencies or outputs.",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
with working_dir(path):
|
||||
for c in dvc_commands:
|
||||
dvc_command = "dvc " + c
|
||||
run_command(dvc_command)
|
||||
with dvc_config_path.open("r+", encoding="utf8") as f:
|
||||
content = f.read()
|
||||
f.seek(0, 0)
|
||||
f.write(f"# {config_hash}\n{DVC_CONFIG_COMMENT}\n{content}")
|
||||
return True
|
||||
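# Illustrative sketch (hypothetical command name and paths): how one workflow step
# is turned into a single "dvc run ... --no-exec" call by the loop above.
name = "train"                              # hypothetical command in project.yml
deps = ["corpus/train.spacy"]               # hypothetical dependency
outputs = ["training/model-best"]           # hypothetical output
project_cmd = ["python", "-m", "spacy", "project", "run", name]
deps_cmd = [c for cl in [["-d", p] for p in deps] for c in cl]
outputs_cmd = [c for cl in [["-o", p] for p in outputs] for c in cl]
dvc_cmd = ["run", "-n", name, "-w", "/path/to/project", "--no-exec"]
full_cmd = [*dvc_cmd, *deps_cmd, *outputs_cmd, *project_cmd]
# full_cmd is later joined and prefixed with "dvc ", e.g.:
# dvc run -n train -w /path/to/project --no-exec -d corpus/train.spacy
#     -o training/model-best python -m spacy project run train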
|
||||
|
||||
def check_workflows(workflows: List[str], workflow: Optional[str] = None) -> None:
|
||||
"""Validate workflows provided in project.yml and check that a given
|
||||
workflow can be used to generate a DVC config.
|
||||
|
||||
workflows (List[str]): Names of the available workflows.
|
||||
workflow (Optional[str]): The name of the workflow to convert.
|
||||
"""
|
||||
if not workflows:
|
||||
msg.fail(
|
||||
f"No workflows defined in {PROJECT_FILE}. To generate a DVC config, "
|
||||
f"define at least one list of commands.",
|
||||
exits=1,
|
||||
)
|
||||
if workflow is not None and workflow not in workflows:
|
||||
msg.fail(
|
||||
f"Workflow '{workflow}' not defined in {PROJECT_FILE}. "
|
||||
f"Available workflows: {', '.join(workflows)}",
|
||||
exits=1,
|
||||
)
|
||||
if not workflow:
|
||||
msg.warn(
|
||||
f"No workflow specified for DVC pipeline. Using the first workflow "
|
||||
f"defined in {PROJECT_FILE}: '{workflows[0]}'"
|
||||
)
|
||||
|
||||
|
||||
def ensure_dvc(project_dir: Path) -> None:
|
||||
"""Ensure that the "dvc" command is available and that the current project
|
||||
directory is an initialized DVC project.
|
||||
"""
|
||||
try:
|
||||
subprocess.run(["dvc", "--version"], stdout=subprocess.DEVNULL)
|
||||
except Exception:
|
||||
msg.fail(
|
||||
"To use spaCy projects with DVC (Data Version Control), DVC needs "
|
||||
"to be installed and the 'dvc' command needs to be available",
|
||||
"You can install the Python package from pip (pip install dvc) or "
|
||||
"conda (conda install -c conda-forge dvc). For more details, see the "
|
||||
"documentation: https://dvc.org/doc/install",
|
||||
exits=1,
|
||||
)
|
||||
if not (project_dir / ".dvc").exists():
|
||||
msg.fail(
|
||||
"Project not initialized as a DVC project",
|
||||
"To initialize a DVC project, you can run 'dvc init' in the project "
|
||||
"directory. For more details, see the documentation: "
|
||||
"https://dvc.org/doc/command-reference/init",
|
||||
exits=1,
|
||||
)
|
|
@ -1,67 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
from .._util import Arg, load_project_config, logger, project_cli
|
||||
from .remote_storage import RemoteStorage, get_command_hash
|
||||
from .run import update_lockfile
|
||||
|
||||
|
||||
@project_cli.command("pull")
|
||||
def project_pull_cli(
|
||||
# fmt: off
|
||||
remote: str = Arg("default", help="Name or path of remote storage"),
|
||||
project_dir: Path = Arg(Path.cwd(), help="Location of project directory. Defaults to current working directory.", exists=True, file_okay=False),
|
||||
# fmt: on
|
||||
):
|
||||
"""Retrieve available precomputed outputs from a remote storage.
|
||||
You can alias remotes in your project.yml by mapping them to storage paths.
|
||||
A storage can be anything that the smart-open library can upload to, e.g.
|
||||
AWS, Google Cloud Storage, SSH, local directories etc.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-pull
|
||||
"""
|
||||
for url, output_path in project_pull(project_dir, remote):
|
||||
if url is not None:
|
||||
msg.good(f"Pulled {output_path} from {url}")
|
||||
|
||||
|
||||
def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
|
||||
# TODO: We don't have tests for this :(. It would take a bit of mockery to
|
||||
# set up. I guess see if it breaks first?
|
||||
config = load_project_config(project_dir)
|
||||
if remote in config.get("remotes", {}):
|
||||
remote = config["remotes"][remote]
|
||||
storage = RemoteStorage(project_dir, remote)
|
||||
commands = list(config.get("commands", []))
|
||||
# We use a while loop here because we don't know how the commands
|
||||
# will be ordered. A command might need dependencies from one that's later
|
||||
# in the list.
|
||||
while commands:
|
||||
for i, cmd in enumerate(list(commands)):
|
||||
logger.debug("CMD: %s.", cmd["name"])
|
||||
deps = [project_dir / dep for dep in cmd.get("deps", [])]
|
||||
if all(dep.exists() for dep in deps):
|
||||
cmd_hash = get_command_hash("", "", deps, cmd["script"])
|
||||
for output_path in cmd.get("outputs", []):
|
||||
url = storage.pull(output_path, command_hash=cmd_hash)
|
||||
logger.debug(
|
||||
"URL: %s for %s with command hash %s",
|
||||
url,
|
||||
output_path,
|
||||
cmd_hash,
|
||||
)
|
||||
yield url, output_path
|
||||
|
||||
out_locs = [project_dir / out for out in cmd.get("outputs", [])]
|
||||
if all(loc.exists() for loc in out_locs):
|
||||
update_lockfile(project_dir, cmd)
|
||||
# We remove the command from the list here, and break, so that
|
||||
# we iterate over the loop again.
|
||||
commands.pop(i)
|
||||
break
|
||||
else:
|
||||
logger.debug("Dependency missing. Skipping %s outputs.", cmd["name"])
|
||||
else:
|
||||
# If we didn't break the for loop, break the while loop.
|
||||
break
|
|
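# Minimal sketch of the loop pattern in project_pull() above: keep sweeping the
# remaining commands, pulling any whose dependencies are present, until a full
# pass makes no progress (the for-loop's "else" then breaks the outer while).
commands = ["package", "train"]             # hypothetical; "package" needs train's output
pulled = {"train": True, "package": False}

while commands:
    for i, cmd in enumerate(list(commands)):
        if pulled[cmd]:
            # ... pull this command's outputs here ...
            if cmd == "train":
                pulled["package"] = True    # train's outputs unblock "package"
            commands.pop(i)
            break
    else:
        break                               # nothing could be pulled this pass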
@ -1,69 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
from .._util import Arg, load_project_config, logger, project_cli
|
||||
from .remote_storage import RemoteStorage, get_command_hash, get_content_hash
|
||||
|
||||
|
||||
@project_cli.command("push")
|
||||
def project_push_cli(
|
||||
# fmt: off
|
||||
remote: str = Arg("default", help="Name or path of remote storage"),
|
||||
project_dir: Path = Arg(Path.cwd(), help="Location of project directory. Defaults to current working directory.", exists=True, file_okay=False),
|
||||
# fmt: on
|
||||
):
|
||||
"""Persist outputs to a remote storage. You can alias remotes in your
|
||||
project.yml by mapping them to storage paths. A storage can be anything that
|
||||
the smart-open library can upload to, e.g. AWS, Google Cloud Storage, SSH,
|
||||
local directories etc.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-push
|
||||
"""
|
||||
for output_path, url in project_push(project_dir, remote):
|
||||
if url is None:
|
||||
msg.info(f"Skipping {output_path}")
|
||||
else:
|
||||
msg.good(f"Pushed {output_path} to {url}")
|
||||
|
||||
|
||||
def project_push(project_dir: Path, remote: str):
|
||||
"""Persist outputs to a remote storage. You can alias remotes in your project.yml
|
||||
by mapping them to storage paths. A storage can be anything that the smart-open
|
||||
library can upload to, e.g. gcs, aws, ssh, local directories etc
|
||||
"""
|
||||
config = load_project_config(project_dir)
|
||||
if remote in config.get("remotes", {}):
|
||||
remote = config["remotes"][remote]
|
||||
storage = RemoteStorage(project_dir, remote)
|
||||
for cmd in config.get("commands", []):
|
||||
logger.debug("CMD: %s", cmd["name"])
|
||||
deps = [project_dir / dep for dep in cmd.get("deps", [])]
|
||||
if any(not dep.exists() for dep in deps):
|
||||
logger.debug("Dependency missing. Skipping %s outputs", cmd["name"])
|
||||
continue
|
||||
cmd_hash = get_command_hash(
|
||||
"", "", [project_dir / dep for dep in cmd.get("deps", [])], cmd["script"]
|
||||
)
|
||||
logger.debug("CMD_HASH: %s", cmd_hash)
|
||||
for output_path in cmd.get("outputs", []):
|
||||
output_loc = project_dir / output_path
|
||||
if output_loc.exists() and _is_not_empty_dir(output_loc):
|
||||
url = storage.push(
|
||||
output_path,
|
||||
command_hash=cmd_hash,
|
||||
content_hash=get_content_hash(output_loc),
|
||||
)
|
||||
logger.debug(
|
||||
"URL: %s for output %s with cmd_hash %s", url, output_path, cmd_hash
|
||||
)
|
||||
yield output_path, url
|
||||
|
||||
|
||||
def _is_not_empty_dir(loc: Path):
|
||||
if not loc.is_dir():
|
||||
return True
|
||||
elif any(_is_not_empty_dir(child) for child in loc.iterdir()):
|
||||
return True
|
||||
else:
|
||||
return False
|
|
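# Usage sketch (hypothetical remote name): calling project_push() above directly
# and reporting results, mirroring the CLI wrapper.
from pathlib import Path

for output_path, url in project_push(Path.cwd(), "default"):
    status = "Skipped" if url is None else f"Pushed to {url}"
    print(output_path, status)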
@ -1,212 +0,0 @@
|
|||
import hashlib
|
||||
import os
|
||||
import site
|
||||
import tarfile
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from wasabi import msg
|
||||
|
||||
from ... import about
|
||||
from ...errors import Errors
|
||||
from ...git_info import GIT_VERSION
|
||||
from ...util import ENV_VARS, check_bool_env_var, get_minor_version
|
||||
from .._util import (
|
||||
download_file,
|
||||
ensure_pathy,
|
||||
get_checksum,
|
||||
get_hash,
|
||||
make_tempdir,
|
||||
upload_file,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathy import FluidPath # noqa: F401
|
||||
|
||||
|
||||
class RemoteStorage:
|
||||
"""Push and pull outputs to and from a remote file storage.
|
||||
|
||||
Remotes can be anything that `smart-open` can support: AWS, GCS, file system,
|
||||
ssh, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, project_root: Path, url: str, *, compression="gz"):
|
||||
self.root = project_root
|
||||
self.url = ensure_pathy(url)
|
||||
self.compression = compression
|
||||
|
||||
def push(self, path: Path, command_hash: str, content_hash: str) -> "FluidPath":
|
||||
"""Compress a file or directory within a project and upload it to a remote
|
||||
storage. If an object exists at the full URL, nothing is done.
|
||||
|
||||
Within the remote storage, files are addressed by their project path
|
||||
(url encoded) and two user-supplied hashes, representing their creation
|
||||
context and their file contents. If the URL already exists, the data is
|
||||
not uploaded. Paths are archived and compressed prior to upload.
|
||||
"""
|
||||
loc = self.root / path
|
||||
if not loc.exists():
|
||||
raise IOError(f"Cannot push {loc}: does not exist.")
|
||||
url = self.make_url(path, command_hash, content_hash)
|
||||
if url.exists():
|
||||
return url
|
||||
tmp: Path
|
||||
with make_tempdir() as tmp:
|
||||
tar_loc = tmp / self.encode_name(str(path))
|
||||
mode_string = f"w:{self.compression}" if self.compression else "w"
|
||||
with tarfile.open(tar_loc, mode=mode_string) as tar_file:
|
||||
tar_file.add(str(loc), arcname=str(path))
|
||||
upload_file(tar_loc, url)
|
||||
return url
|
||||
|
||||
def pull(
|
||||
self,
|
||||
path: Path,
|
||||
*,
|
||||
command_hash: Optional[str] = None,
|
||||
content_hash: Optional[str] = None,
|
||||
) -> Optional["FluidPath"]:
|
||||
"""Retrieve a file from the remote cache. If the file already exists,
|
||||
nothing is done.
|
||||
|
||||
If the command_hash and/or content_hash are specified, only matching
|
||||
results are returned. If no results are available, an error is raised.
|
||||
"""
|
||||
dest = self.root / path
|
||||
if dest.exists():
|
||||
return None
|
||||
url = self.find(path, command_hash=command_hash, content_hash=content_hash)
|
||||
if url is None:
|
||||
return url
|
||||
else:
|
||||
# Make sure the destination exists
|
||||
if not dest.parent.exists():
|
||||
dest.parent.mkdir(parents=True)
|
||||
tmp: Path
|
||||
with make_tempdir() as tmp:
|
||||
tar_loc = tmp / url.parts[-1]
|
||||
download_file(url, tar_loc)
|
||||
mode_string = f"r:{self.compression}" if self.compression else "r"
|
||||
with tarfile.open(tar_loc, mode=mode_string) as tar_file:
|
||||
# This requires that the path is added correctly, relative
|
||||
# to root. This is how we set things up in push()
|
||||
|
||||
# Disallow paths outside the current directory for the tar
|
||||
# file (CVE-2007-4559, directory traversal vulnerability)
|
||||
def is_within_directory(directory, target):
|
||||
abs_directory = os.path.abspath(directory)
|
||||
abs_target = os.path.abspath(target)
|
||||
prefix = os.path.commonprefix([abs_directory, abs_target])
|
||||
return prefix == abs_directory
|
||||
|
||||
def safe_extract(tar, path):
|
||||
for member in tar.getmembers():
|
||||
member_path = os.path.join(path, member.name)
|
||||
if not is_within_directory(path, member_path):
|
||||
raise ValueError(Errors.E852)
|
||||
tar.extractall(path)
|
||||
|
||||
safe_extract(tar_file, self.root)
|
||||
return url
|
||||
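# Illustrative sketch (hypothetical paths): the directory-traversal guard used by
# safe_extract() above; members resolving outside the extraction root are rejected.
import os

root = os.path.abspath("/tmp/project")
inside = os.path.abspath("/tmp/project/training/model-best")
outside = os.path.abspath("/tmp/project/../etc/passwd")         # -> "/tmp/etc/passwd"
os.path.commonprefix([root, inside]) == root                    # True: allowed
os.path.commonprefix([root, outside]) == root                   # False: raises E852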
|
||||
def find(
|
||||
self,
|
||||
path: Path,
|
||||
*,
|
||||
command_hash: Optional[str] = None,
|
||||
content_hash: Optional[str] = None,
|
||||
) -> Optional["FluidPath"]:
|
||||
"""Find the best matching version of a file within the storage,
|
||||
or `None` if no match can be found. If both the creation and content hash
|
||||
are specified, only exact matches will be returned. Otherwise, the most
|
||||
recent matching file is preferred.
|
||||
"""
|
||||
name = self.encode_name(str(path))
|
||||
urls = []
|
||||
if command_hash is not None and content_hash is not None:
|
||||
url = self.url / name / command_hash / content_hash
|
||||
urls = [url] if url.exists() else []
|
||||
elif command_hash is not None:
|
||||
if (self.url / name / command_hash).exists():
|
||||
urls = list((self.url / name / command_hash).iterdir())
|
||||
else:
|
||||
if (self.url / name).exists():
|
||||
for sub_dir in (self.url / name).iterdir():
|
||||
urls.extend(sub_dir.iterdir())
|
||||
if content_hash is not None:
|
||||
urls = [url for url in urls if url.parts[-1] == content_hash]
|
||||
if len(urls) >= 2:
|
||||
try:
|
||||
urls.sort(key=lambda x: x.stat().last_modified) # type: ignore
|
||||
except Exception:
|
||||
msg.warn(
|
||||
"Unable to sort remote files by last modified. The file(s) "
|
||||
"pulled from the cache may not be the most recent."
|
||||
)
|
||||
return urls[-1] if urls else None
|
||||
|
||||
def make_url(self, path: Path, command_hash: str, content_hash: str) -> "FluidPath":
|
||||
"""Construct a URL from a subpath, a creation hash and a content hash."""
|
||||
return self.url / self.encode_name(str(path)) / command_hash / content_hash
|
||||
|
||||
def encode_name(self, name: str) -> str:
|
||||
"""Encode a subpath into a URL-safe name."""
|
||||
return urllib.parse.quote_plus(name)
|
||||
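# Illustrative sketch (hypothetical remote and hashes): the addressing scheme built
# by make_url()/encode_name() above -- URL-encoded project path, then command hash,
# then content hash.
import urllib.parse

remote = "s3://my-bucket/spacy-cache"       # hypothetical remote storage
path = "training/model-best"                # hypothetical output path
key = f"{remote}/{urllib.parse.quote_plus(path)}/1a2b3c/4d5e6f"
# key == "s3://my-bucket/spacy-cache/training%2Fmodel-best/1a2b3c/4d5e6f"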
|
||||
|
||||
def get_content_hash(loc: Path) -> str:
|
||||
return get_checksum(loc)
|
||||
|
||||
|
||||
def get_command_hash(
|
||||
site_hash: str, env_hash: str, deps: List[Path], cmd: List[str]
|
||||
) -> str:
|
||||
"""Create a hash representing the execution of a command. This includes the
|
||||
currently installed packages, whatever environment variables have been marked
|
||||
as relevant, and the command.
|
||||
"""
|
||||
if check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION):
|
||||
spacy_v = GIT_VERSION
|
||||
else:
|
||||
spacy_v = str(get_minor_version(about.__version__) or "")
|
||||
dep_checksums = [get_checksum(dep) for dep in sorted(deps)]
|
||||
hashes = [spacy_v, site_hash, env_hash] + dep_checksums
|
||||
hashes.extend(cmd)
|
||||
creation_bytes = "".join(hashes).encode("utf8")
|
||||
return hashlib.md5(creation_bytes).hexdigest()
|
||||
|
||||
|
||||
def get_site_hash():
|
||||
"""Hash the current Python environment's site-packages contents, including
|
||||
the name and version of the libraries. The list we're hashing is what
|
||||
`pip freeze` would output.
|
||||
"""
|
||||
site_dirs = site.getsitepackages()
|
||||
if site.ENABLE_USER_SITE:
|
||||
site_dirs.append(site.getusersitepackages())  # getusersitepackages() returns a single path string
|
||||
packages = set()
|
||||
for site_dir in site_dirs:
|
||||
site_dir = Path(site_dir)
|
||||
for subpath in site_dir.iterdir():
|
||||
if subpath.parts[-1].endswith("dist-info"):
|
||||
packages.add(subpath.parts[-1].replace(".dist-info", ""))
|
||||
package_bytes = "".join(sorted(packages)).encode("utf8")
|
||||
return hashlib.md5(package_bytes).hexdigest()
|
||||
|
||||
|
||||
def get_env_hash(env: Dict[str, str]) -> str:
|
||||
"""Construct a hash of the environment variables that will be passed into
|
||||
the commands.
|
||||
|
||||
Values in the env dict may be references to the current os.environ, using
|
||||
the syntax $ENV_VAR to mean os.environ[ENV_VAR]
|
||||
"""
|
||||
env_vars = {}
|
||||
for key, value in env.items():
|
||||
if value.startswith("$"):
|
||||
env_vars[key] = os.environ.get(value[1:], "")
|
||||
else:
|
||||
env_vars[key] = value
|
||||
return get_hash(env_vars)
|
|
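# Illustrative sketch (hypothetical variable names): how "$VAR" references in the
# env mapping above are resolved against os.environ before hashing.
import os

os.environ["WANDB_API_KEY"] = "secret"                      # hypothetical
env = {"WANDB_API_KEY": "$WANDB_API_KEY", "SEED": "42"}
resolved = {
    k: os.environ.get(v[1:], "") if v.startswith("$") else v
    for k, v in env.items()
}
# resolved == {"WANDB_API_KEY": "secret", "SEED": "42"}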
@ -1,379 +0,0 @@
|
|||
import os.path
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import srsly
|
||||
import typer
|
||||
from wasabi import msg
|
||||
from wasabi.util import locale_escape
|
||||
|
||||
from ... import about
|
||||
from ...git_info import GIT_VERSION
|
||||
from ...util import (
|
||||
ENV_VARS,
|
||||
SimpleFrozenDict,
|
||||
SimpleFrozenList,
|
||||
check_bool_env_var,
|
||||
is_cwd,
|
||||
is_minor_version_match,
|
||||
join_command,
|
||||
run_command,
|
||||
split_command,
|
||||
working_dir,
|
||||
)
|
||||
from .._util import (
|
||||
COMMAND,
|
||||
PROJECT_FILE,
|
||||
PROJECT_LOCK,
|
||||
Arg,
|
||||
Opt,
|
||||
get_checksum,
|
||||
get_hash,
|
||||
load_project_config,
|
||||
parse_config_overrides,
|
||||
project_cli,
|
||||
)
|
||||
|
||||
|
||||
@project_cli.command(
|
||||
"run", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
|
||||
)
|
||||
def project_run_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
subcommand: str = Arg(None, help=f"Name of command defined in the {PROJECT_FILE}"),
|
||||
project_dir: Path = Arg(Path.cwd(), help="Location of project directory. Defaults to current working directory.", exists=True, file_okay=False),
|
||||
force: bool = Opt(False, "--force", "-F", help="Force re-running steps, even if nothing changed"),
|
||||
dry: bool = Opt(False, "--dry", "-D", help="Perform a dry run and don't execute scripts"),
|
||||
show_help: bool = Opt(False, "--help", help="Show help message and available subcommands")
|
||||
# fmt: on
|
||||
):
|
||||
"""Run a named command or workflow defined in the project.yml. If a workflow
|
||||
name is specified, all commands in the workflow are run, in order. If
|
||||
commands define dependencies and/or outputs, they will only be re-run if
|
||||
state has changed.
|
||||
|
||||
DOCS: https://spacy.io/api/cli#project-run
|
||||
"""
|
||||
if show_help or not subcommand:
|
||||
print_run_help(project_dir, subcommand)
|
||||
else:
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
project_run(project_dir, subcommand, overrides=overrides, force=force, dry=dry)
|
||||
|
||||
|
||||
def project_run(
|
||||
project_dir: Path,
|
||||
subcommand: str,
|
||||
*,
|
||||
overrides: Dict[str, Any] = SimpleFrozenDict(),
|
||||
force: bool = False,
|
||||
dry: bool = False,
|
||||
capture: bool = False,
|
||||
skip_requirements_check: bool = False,
|
||||
) -> None:
|
||||
"""Run a named script defined in the project.yml. If the script is part
|
||||
of the default pipeline (defined in the "run" section), DVC is used to
|
||||
execute the command, so it can determine whether to rerun it. It then
|
||||
calls into "exec" to execute it.
|
||||
|
||||
project_dir (Path): Path to project directory.
|
||||
subcommand (str): Name of command to run.
|
||||
overrides (Dict[str, Any]): Optional config overrides.
|
||||
force (bool): Force re-running, even if nothing changed.
|
||||
dry (bool): Perform a dry run and don't execute commands.
|
||||
capture (bool): Whether to capture the output and errors of individual commands.
|
||||
If False, the stdout and stderr will not be redirected, and if there's an error,
|
||||
sys.exit will be called with the return code. You should use capture=False
|
||||
when you want to turn over execution to the command, and capture=True
|
||||
when you want to run the command more like a function.
|
||||
skip_requirements_check (bool): Whether to skip the requirements check.
|
||||
"""
|
||||
config = load_project_config(project_dir, overrides=overrides)
|
||||
commands = {cmd["name"]: cmd for cmd in config.get("commands", [])}
|
||||
workflows = config.get("workflows", {})
|
||||
validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
|
||||
|
||||
req_path = project_dir / "requirements.txt"
|
||||
if not skip_requirements_check:
|
||||
if config.get("check_requirements", True) and os.path.exists(req_path):
|
||||
with req_path.open() as requirements_file:
|
||||
_check_requirements([req.strip() for req in requirements_file])
|
||||
|
||||
if subcommand in workflows:
|
||||
msg.info(f"Running workflow '{subcommand}'")
|
||||
for cmd in workflows[subcommand]:
|
||||
project_run(
|
||||
project_dir,
|
||||
cmd,
|
||||
overrides=overrides,
|
||||
force=force,
|
||||
dry=dry,
|
||||
capture=capture,
|
||||
skip_requirements_check=True,
|
||||
)
|
||||
else:
|
||||
cmd = commands[subcommand]
|
||||
for dep in cmd.get("deps", []):
|
||||
if not (project_dir / dep).exists():
|
||||
err = f"Missing dependency specified by command '{subcommand}': {dep}"
|
||||
err_help = "Maybe you forgot to run the 'project assets' command or a previous step?"
|
||||
err_exits = 1 if not dry else None
|
||||
msg.fail(err, err_help, exits=err_exits)
|
||||
check_spacy_commit = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
|
||||
with working_dir(project_dir) as current_dir:
|
||||
msg.divider(subcommand)
|
||||
rerun = check_rerun(current_dir, cmd, check_spacy_commit=check_spacy_commit)
|
||||
if not rerun and not force:
|
||||
msg.info(f"Skipping '{cmd['name']}': nothing changed")
|
||||
else:
|
||||
run_commands(cmd["script"], dry=dry, capture=capture)
|
||||
if not dry:
|
||||
update_lockfile(current_dir, cmd)
|
||||
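# Usage sketch (hypothetical workflow name and override): running a workflow
# programmatically with the same call the CLI wrapper above makes.
from pathlib import Path

project_run(
    Path.cwd(),
    "all",                                  # hypothetical workflow in project.yml
    overrides={"vars.gpu_id": 0},           # hypothetical variable override
    force=False,
    dry=True,                               # print the steps without executing them
)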
|
||||
|
||||
def print_run_help(project_dir: Path, subcommand: Optional[str] = None) -> None:
|
||||
"""Simulate a CLI help prompt using the info available in the project.yml.
|
||||
|
||||
project_dir (Path): The project directory.
|
||||
subcommand (Optional[str]): The subcommand or None. If a subcommand is
|
||||
provided, the subcommand help is shown. Otherwise, the top-level help
|
||||
and a list of available commands is printed.
|
||||
"""
|
||||
config = load_project_config(project_dir)
|
||||
config_commands = config.get("commands", [])
|
||||
commands = {cmd["name"]: cmd for cmd in config_commands}
|
||||
workflows = config.get("workflows", {})
|
||||
project_loc = "" if is_cwd(project_dir) else project_dir
|
||||
if subcommand:
|
||||
validate_subcommand(list(commands.keys()), list(workflows.keys()), subcommand)
|
||||
print(f"Usage: {COMMAND} project run {subcommand} {project_loc}")
|
||||
if subcommand in commands:
|
||||
help_text = commands[subcommand].get("help")
|
||||
if help_text:
|
||||
print(f"\n{help_text}\n")
|
||||
elif subcommand in workflows:
|
||||
steps = workflows[subcommand]
|
||||
print(f"\nWorkflow consisting of {len(steps)} commands:")
|
||||
steps_data = [
|
||||
(f"{i + 1}. {step}", commands[step].get("help", ""))
|
||||
for i, step in enumerate(steps)
|
||||
]
|
||||
msg.table(steps_data)
|
||||
help_cmd = f"{COMMAND} project run [COMMAND] {project_loc} --help"
|
||||
print(f"For command details, run: {help_cmd}")
|
||||
else:
|
||||
print("")
|
||||
title = config.get("title")
|
||||
if title:
|
||||
print(f"{locale_escape(title)}\n")
|
||||
if config_commands:
|
||||
print(f"Available commands in {PROJECT_FILE}")
|
||||
print(f"Usage: {COMMAND} project run [COMMAND] {project_loc}")
|
||||
msg.table([(cmd["name"], cmd.get("help", "")) for cmd in config_commands])
|
||||
if workflows:
|
||||
print(f"Available workflows in {PROJECT_FILE}")
|
||||
print(f"Usage: {COMMAND} project run [WORKFLOW] {project_loc}")
|
||||
msg.table([(name, " -> ".join(steps)) for name, steps in workflows.items()])
|
||||
|
||||
|
||||
def run_commands(
|
||||
commands: Iterable[str] = SimpleFrozenList(),
|
||||
silent: bool = False,
|
||||
dry: bool = False,
|
||||
capture: bool = False,
|
||||
) -> None:
|
||||
"""Run a sequence of commands in a subprocess, in order.
|
||||
|
||||
commands (List[str]): The string commands.
|
||||
silent (bool): Don't print the commands.
|
||||
dry (bool): Perform a dry run and don't execute anything.
|
||||
capture (bool): Whether to capture the output and errors of individual commands.
|
||||
If False, the stdout and stderr will not be redirected, and if there's an error,
|
||||
sys.exit will be called with the return code. You should use capture=False
|
||||
when you want to turn over execution to the command, and capture=True
|
||||
when you want to run the command more like a function.
|
||||
"""
|
||||
for c in commands:
|
||||
command = split_command(c)
|
||||
# Not sure if this is needed or a good idea. Motivation: users may often
|
||||
# use commands in their config that reference "python" and we want to
|
||||
# make sure that it's always executing the same Python that spaCy is
|
||||
# executed with and the pip in the same env, not some other Python/pip.
|
||||
# Also ensures cross-compatibility if user 1 writes "python3" (because
|
||||
# that's how it's set up on their system), and user 2 without the
|
||||
# shortcut tries to re-run the command.
|
||||
if len(command) and command[0] in ("python", "python3"):
|
||||
command[0] = sys.executable
|
||||
elif len(command) and command[0] in ("pip", "pip3"):
|
||||
command = [sys.executable, "-m", "pip", *command[1:]]
|
||||
if not silent:
|
||||
print(f"Running command: {join_command(command)}")
|
||||
if not dry:
|
||||
run_command(command, capture=capture)
|
||||
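# Illustrative sketch (hypothetical script line): the interpreter substitution above,
# shown with shlex.split instead of spaCy's split_command helper.
import shlex
import sys

command = shlex.split("pip install -r requirements.txt")
if command and command[0] in ("python", "python3"):
    command[0] = sys.executable
elif command and command[0] in ("pip", "pip3"):
    command = [sys.executable, "-m", "pip", *command[1:]]
# command == [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]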
|
||||
|
||||
def validate_subcommand(
|
||||
commands: Sequence[str], workflows: Sequence[str], subcommand: str
|
||||
) -> None:
|
||||
"""Check that a subcommand is valid and defined. Raises an error otherwise.
|
||||
|
||||
commands (Sequence[str]): The available commands.
workflows (Sequence[str]): The available workflows.
subcommand (str): The subcommand.
|
||||
"""
|
||||
if not commands and not workflows:
|
||||
msg.fail(f"No commands or workflows defined in {PROJECT_FILE}", exits=1)
|
||||
if subcommand not in commands and subcommand not in workflows:
|
||||
help_msg = []
|
||||
if subcommand in ["assets", "asset"]:
|
||||
help_msg.append("Did you mean to run: python -m spacy project assets?")
|
||||
if commands:
|
||||
help_msg.append(f"Available commands: {', '.join(commands)}")
|
||||
if workflows:
|
||||
help_msg.append(f"Available workflows: {', '.join(workflows)}")
|
||||
msg.fail(
|
||||
f"Can't find command or workflow '{subcommand}' in {PROJECT_FILE}",
|
||||
". ".join(help_msg),
|
||||
exits=1,
|
||||
)
|
||||
|
||||
|
||||
def check_rerun(
|
||||
project_dir: Path,
|
||||
command: Dict[str, Any],
|
||||
*,
|
||||
check_spacy_version: bool = True,
|
||||
check_spacy_commit: bool = False,
|
||||
) -> bool:
|
||||
"""Check if a command should be rerun because its settings or inputs/outputs
|
||||
changed.
|
||||
|
||||
project_dir (Path): The current project directory.
|
||||
command (Dict[str, Any]): The command, as defined in the project.yml.
|
||||
check_spacy_version (bool): Whether to check the spaCy minor version.
check_spacy_commit (bool): Whether to check the spaCy Git commit.
|
||||
RETURNS (bool): Whether to re-run the command.
|
||||
"""
|
||||
# Always rerun if no-skip is set
|
||||
if command.get("no_skip", False):
|
||||
return True
|
||||
lock_path = project_dir / PROJECT_LOCK
|
||||
if not lock_path.exists(): # We don't have a lockfile, run command
|
||||
return True
|
||||
data = srsly.read_yaml(lock_path)
|
||||
if command["name"] not in data: # We don't have info about this command
|
||||
return True
|
||||
entry = data[command["name"]]
|
||||
# Always run commands with no outputs (otherwise they'd always be skipped)
|
||||
if not entry.get("outs", []):
|
||||
return True
|
||||
# Always rerun if spaCy version or commit hash changed
|
||||
spacy_v = entry.get("spacy_version")
|
||||
commit = entry.get("spacy_git_version")
|
||||
if check_spacy_version and not is_minor_version_match(spacy_v, about.__version__):
|
||||
info = f"({spacy_v} in {PROJECT_LOCK}, {about.__version__} current)"
|
||||
msg.info(f"Re-running '{command['name']}': spaCy minor version changed {info}")
|
||||
return True
|
||||
if check_spacy_commit and commit != GIT_VERSION:
|
||||
info = f"({commit} in {PROJECT_LOCK}, {GIT_VERSION} current)"
|
||||
msg.info(f"Re-running '{command['name']}': spaCy commit changed {info}")
|
||||
return True
|
||||
# If the entry in the lockfile matches the lockfile entry that would be
|
||||
# generated from the current command, we don't rerun because it means that
|
||||
# all inputs/outputs, hashes and scripts are the same and nothing changed
|
||||
lock_entry = get_lock_entry(project_dir, command)
|
||||
exclude = ["spacy_version", "spacy_git_version"]
|
||||
return get_hash(lock_entry, exclude=exclude) != get_hash(entry, exclude=exclude)
|
||||
|
||||
|
||||
def update_lockfile(project_dir: Path, command: Dict[str, Any]) -> None:
|
||||
"""Update the lockfile after running a command. Will create a lockfile if
|
||||
it doesn't yet exist and will add an entry for the current command, its
|
||||
script and dependencies/outputs.
|
||||
|
||||
project_dir (Path): The current project directory.
|
||||
command (Dict[str, Any]): The command, as defined in the project.yml.
|
||||
"""
|
||||
lock_path = project_dir / PROJECT_LOCK
|
||||
if not lock_path.exists():
|
||||
srsly.write_yaml(lock_path, {})
|
||||
data = {}
|
||||
else:
|
||||
data = srsly.read_yaml(lock_path)
|
||||
data[command["name"]] = get_lock_entry(project_dir, command)
|
||||
srsly.write_yaml(lock_path, data)
|
||||
|
||||
|
||||
def get_lock_entry(project_dir: Path, command: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get a lockfile entry for a given command. An entry includes the command,
|
||||
the script (command steps) and a list of dependencies and outputs with
|
||||
their paths and file hashes, if available. The format is based on the
|
||||
dvc.lock files, to keep things consistent.
|
||||
|
||||
project_dir (Path): The current project directory.
|
||||
command (Dict[str, Any]): The command, as defined in the project.yml.
|
||||
RETURNS (Dict[str, Any]): The lockfile entry.
|
||||
"""
|
||||
deps = get_fileinfo(project_dir, command.get("deps", []))
|
||||
outs = get_fileinfo(project_dir, command.get("outputs", []))
|
||||
outs_nc = get_fileinfo(project_dir, command.get("outputs_no_cache", []))
|
||||
return {
|
||||
"cmd": f"{COMMAND} run {command['name']}",
|
||||
"script": command["script"],
|
||||
"deps": deps,
|
||||
"outs": [*outs, *outs_nc],
|
||||
"spacy_version": about.__version__,
|
||||
"spacy_git_version": GIT_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def get_fileinfo(project_dir: Path, paths: List[str]) -> List[Dict[str, Optional[str]]]:
|
||||
"""Generate the file information for a list of paths (dependencies, outputs).
|
||||
Includes the file path and the file's checksum.
|
||||
|
||||
project_dir (Path): The current project directory.
|
||||
paths (List[str]): The file paths.
|
||||
RETURNS (List[Dict[str, str]]): The lockfile entry for a file.
|
||||
"""
|
||||
data = []
|
||||
for path in paths:
|
||||
file_path = project_dir / path
|
||||
md5 = get_checksum(file_path) if file_path.exists() else None
|
||||
data.append({"path": path, "md5": md5})
|
||||
return data
|
||||
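# Illustrative sketch (hypothetical command, paths and hashes): the shape of one
# project.lock entry produced by get_lock_entry()/get_fileinfo() above.
lock_entry = {
    "cmd": "python -m spacy run train",
    "script": ["python scripts/train.py corpus/train.spacy training/"],
    "deps": [{"path": "corpus/train.spacy", "md5": "0f1e2d3c..."}],
    "outs": [{"path": "training/model-best", "md5": None}],  # md5 is None if the file is missing
    "spacy_version": "3.6.0",
    "spacy_git_version": "abc1234",
}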
|
||||
|
||||
def _check_requirements(requirements: List[str]) -> Tuple[bool, bool]:
|
||||
"""Checks whether requirements are installed and free of version conflicts.
|
||||
requirements (List[str]): List of requirements.
|
||||
RETURNS (Tuple[bool, bool]): Whether (1) any packages couldn't be imported, (2) any packages with version conflicts
|
||||
exist.
|
||||
"""
|
||||
import pkg_resources
|
||||
|
||||
failed_pkgs_msgs: List[str] = []
|
||||
conflicting_pkgs_msgs: List[str] = []
|
||||
|
||||
for req in requirements:
|
||||
try:
|
||||
pkg_resources.require(req)
|
||||
except pkg_resources.DistributionNotFound as dnf:
|
||||
failed_pkgs_msgs.append(dnf.report())
|
||||
except pkg_resources.VersionConflict as vc:
|
||||
conflicting_pkgs_msgs.append(vc.report())
|
||||
except Exception:
|
||||
msg.warn(
|
||||
f"Unable to check requirement: {req} "
|
||||
"Checks are currently limited to requirement specifiers "
|
||||
"(PEP 508)"
|
||||
)
|
||||
|
||||
if len(failed_pkgs_msgs) or len(conflicting_pkgs_msgs):
|
||||
msg.warn(
|
||||
title="Missing requirements or requirement conflicts detected. Make sure your Python environment is set up "
|
||||
"correctly and you installed all requirements specified in your project's requirements.txt: "
|
||||
)
|
||||
for pgk_msg in failed_pkgs_msgs + conflicting_pkgs_msgs:
|
||||
msg.text(pgk_msg)
|
||||
|
||||
return len(failed_pkgs_msgs) > 0, len(conflicting_pkgs_msgs) > 0
|
|
@ -47,7 +47,8 @@ def train_cli(
|
|||
|
||||
DOCS: https://spacy.io/api/cli#train
|
||||
"""
|
||||
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
if verbose:
|
||||
util.logger.setLevel(logging.DEBUG)
|
||||
overrides = parse_config_overrides(ctx.args)
|
||||
import_code(code_path)
|
||||
train(config_path, output_path, use_gpu=use_gpu, overrides=overrides)
|
||||
|
|
|
@ -26,6 +26,9 @@ batch_size = 1000
|
|||
[nlp.tokenizer]
|
||||
@tokenizers = "spacy.Tokenizer.v1"
|
||||
|
||||
[nlp.vectors]
|
||||
@vectors = "spacy.Vectors.v1"
|
||||
|
||||
# The pipeline components and their models
|
||||
[components]
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import itertools
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -218,7 +217,7 @@ class SpanRenderer:
|
|||
+ (self.offset_step * (len(entities) - 1))
|
||||
)
|
||||
markup += self.span_template.format(
|
||||
text=token["text"],
|
||||
text=escape_html(token["text"]),
|
||||
span_slices=slices,
|
||||
span_starts=starts,
|
||||
total_height=total_height,
|
||||
|
@ -314,6 +313,8 @@ class DependencyRenderer:
|
|||
self.lang = settings.get("lang", DEFAULT_LANG)
|
||||
render_id = f"{id_prefix}-{i}"
|
||||
svg = self.render_svg(render_id, p["words"], p["arcs"])
|
||||
if p.get("title"):
|
||||
svg = TPL_TITLE.format(title=p.get("title")) + svg
|
||||
rendered.append(svg)
|
||||
if page:
|
||||
content = "".join([TPL_FIGURE.format(content=svg) for svg in rendered])
|
||||
|
@ -566,7 +567,7 @@ class EntityRenderer:
|
|||
for i, fragment in enumerate(fragments):
|
||||
markup += escape_html(fragment)
|
||||
if len(fragments) > 1 and i != len(fragments) - 1:
|
||||
markup += "</br>"
|
||||
markup += "<br>"
|
||||
if self.ents is None or label.upper() in self.ents:
|
||||
color = self.colors.get(label.upper(), self.default_color)
|
||||
ent_settings = {
|
||||
|
@ -584,7 +585,7 @@ class EntityRenderer:
|
|||
for i, fragment in enumerate(fragments):
|
||||
markup += escape_html(fragment)
|
||||
if len(fragments) > 1 and i != len(fragments) - 1:
|
||||
markup += "</br>"
|
||||
markup += "<br>"
|
||||
markup = TPL_ENTS.format(content=markup, dir=self.direction)
|
||||
if title:
|
||||
markup = TPL_TITLE.format(title=title) + markup
|
||||
|
|
|
@ -219,6 +219,7 @@ class Warnings(metaclass=ErrorsWithCodes):
|
|||
W125 = ("The StaticVectors key_attr is no longer used. To set a custom "
|
||||
"key attribute for vectors, configure it through Vectors(attr=) or "
|
||||
"'spacy init vectors --attr'")
|
||||
W126 = ("These keys are unsupported: {unsupported}")
|
||||
|
||||
|
||||
class Errors(metaclass=ErrorsWithCodes):
|
||||
|
@ -553,12 +554,12 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"during training, make sure to include it in 'annotating components'")
|
||||
|
||||
# New errors added in v3.x
|
||||
E849 = ("The vocab only supports {method} for vectors of type "
|
||||
"spacy.vectors.Vectors, not {vectors_type}.")
|
||||
E850 = ("The PretrainVectors objective currently only supports default or "
|
||||
"floret vectors, not {mode} vectors.")
|
||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||
"but found value of '{val}'.")
|
||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
||||
"traversal.")
|
||||
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
||||
"not permitted in factory names.")
|
||||
E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "
|
||||
|
@ -981,6 +982,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
" 'min_length': {min_length}, 'max_length': {max_length}")
|
||||
E1054 = ("The text, including whitespace, must match between reference and "
|
||||
"predicted docs when training {component}.")
|
||||
E1055 = ("The 'replace_listener' callback expects {num_params} parameters, "
|
||||
"but only callbacks with one or three parameters are supported")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -4,7 +4,8 @@ from ..typedefs cimport hash_t
|
|||
from .kb cimport KnowledgeBase
|
||||
|
||||
|
||||
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
|
||||
# Object used by the Entity Linker that summarizes one entity-alias candidate
|
||||
# combination.
|
||||
cdef class Candidate:
|
||||
cdef readonly KnowledgeBase kb
|
||||
cdef hash_t entity_hash
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
# cython: infer_types=True
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
|
@ -8,15 +8,24 @@ from ..tokens import Span
|
|||
|
||||
|
||||
cdef class Candidate:
|
||||
"""A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved
|
||||
to a specific `entity` from a Knowledge Base. This will be used as input for the entity linking
|
||||
algorithm which will disambiguate the various candidates to the correct one.
|
||||
"""A `Candidate` object refers to a textual mention (`alias`) that may or
|
||||
may not be resolved to a specific `entity` from a Knowledge Base. This
|
||||
will be used as input for the entity linking algorithm which will
|
||||
disambiguate the various candidates to the correct one.
|
||||
Each candidate (alias, entity) pair is assigned a certain prior probability.
|
||||
|
||||
DOCS: https://spacy.io/api/kb/#candidate-init
|
||||
"""
|
||||
|
||||
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
|
||||
def __init__(
|
||||
self,
|
||||
KnowledgeBase kb,
|
||||
entity_hash,
|
||||
entity_freq,
|
||||
entity_vector,
|
||||
alias_hash,
|
||||
prior_prob
|
||||
):
|
||||
self.kb = kb
|
||||
self.entity_hash = entity_hash
|
||||
self.entity_freq = entity_freq
|
||||
|
@ -59,7 +68,8 @@ cdef class Candidate:
|
|||
|
||||
def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
|
||||
"""
|
||||
Return candidate entities for a given mention and fetching appropriate entries from the index.
|
||||
Return candidate entities for a given mention and fetching appropriate
|
||||
entries from the index.
|
||||
kb (KnowledgeBase): Knowledge base to query.
|
||||
mention (Span): Entity mention for which to identify candidates.
|
||||
RETURNS (Iterable[Candidate]): Identified candidates.
|
||||
|
@@ -67,9 +77,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
return kb.get_candidates(mention)


def get_candidates_batch(kb: KnowledgeBase, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
def get_candidates_batch(
kb: KnowledgeBase, mentions: Iterable[Span]
) -> Iterable[Iterable[Candidate]]:
"""
Return candidate entities for the given mentions and fetching appropriate entries from the index.
Return candidate entities for the given mentions and fetching appropriate entries
from the index.
kb (KnowledgeBase): Knowledge base to query.
mention (Iterable[Span]): Entity mentions for which to identify candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
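As a usage sketch for the two helpers above (import paths assumed from this changeset's module layout; entity IDs, frequencies and vectors are made-up values):

```python
# Hedged sketch: querying candidates for a single mention and for a batch.
import spacy
from spacy.kb import InMemoryLookupKB, get_candidates, get_candidates_batch

nlp = spacy.blank("en")
kb = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q1004791", freq=6, entity_vector=[0, 3, 5])
kb.add_entity(entity="Q42", freq=342, entity_vector=[1, 9, -3])
kb.add_alias(alias="Douglas", entities=["Q1004791", "Q42"], probabilities=[0.6, 0.3])

doc = nlp("Douglas wrote a book.")
mention = doc[0:1]  # the Span "Douglas"
for candidate in get_candidates(kb, mention):
    print(candidate.entity_, candidate.prior_prob)

# The batched variant falls back to per-mention lookup by default.
print(get_candidates_batch(kb, [mention]))
```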
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
# cython: infer_types=True
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple, Union
|
||||
|
@ -12,8 +12,9 @@ from .candidate import Candidate
|
|||
|
||||
|
||||
cdef class KnowledgeBase:
|
||||
"""A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases,
|
||||
to support entity linking of named entities to real-world concepts.
|
||||
"""A `KnowledgeBase` instance stores unique identifiers for entities and
|
||||
their textual aliases, to support entity linking of named entities to
|
||||
real-world concepts.
|
||||
This is an abstract class and requires its operations to be implemented.
|
||||
|
||||
DOCS: https://spacy.io/api/kb
|
||||
|
@ -31,10 +32,13 @@ cdef class KnowledgeBase:
|
|||
self.entity_vector_length = entity_vector_length
|
||||
self.mem = Pool()
|
||||
|
||||
def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
|
||||
def get_candidates_batch(
|
||||
self, mentions: Iterable[Span]
|
||||
) -> Iterable[Iterable[Candidate]]:
|
||||
"""
|
||||
Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
Return candidate entities for specified texts. Each candidate defines
|
||||
the entity, the original alias, and the prior probability of that
|
||||
alias resolving to that entity.
|
||||
If no candidate is found for a given text, an empty list is returned.
|
||||
mentions (Iterable[Span]): Mentions for which to get candidates.
|
||||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
|
||||
|
@ -43,14 +47,17 @@ cdef class KnowledgeBase:
|
|||
|
||||
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
|
||||
"""
|
||||
Return candidate entities for specified text. Each candidate defines the entity, the original alias,
|
||||
Return candidate entities for specified text. Each candidate defines
|
||||
the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If the no candidate is found for a given text, an empty list is returned.
|
||||
mention (Span): Mention for which to get candidates.
|
||||
RETURNS (Iterable[Candidate]): Identified candidates.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
|
||||
Errors.E1045.format(
|
||||
parent="KnowledgeBase", method="get_candidates", name=self.__name__
|
||||
)
|
||||
)
|
||||
|
||||
def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]:
|
||||
|
@ -68,7 +75,9 @@ cdef class KnowledgeBase:
|
|||
RETURNS (Iterable[float]): Vector for specified entity.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
Errors.E1045.format(parent="KnowledgeBase", method="get_vector", name=self.__name__)
|
||||
Errors.E1045.format(
|
||||
parent="KnowledgeBase", method="get_vector", name=self.__name__
|
||||
)
|
||||
)
|
||||
|
||||
def to_bytes(self, **kwargs) -> bytes:
|
||||
|
@ -76,7 +85,9 @@ cdef class KnowledgeBase:
|
|||
RETURNS (bytes): Current state as binary string.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
Errors.E1045.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__)
|
||||
Errors.E1045.format(
|
||||
parent="KnowledgeBase", method="to_bytes", name=self.__name__
|
||||
)
|
||||
)
|
||||
|
||||
def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()):
|
||||
|
@ -85,25 +96,35 @@ cdef class KnowledgeBase:
|
|||
exclude (Tuple[str]): Properties to exclude when restoring KB.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
Errors.E1045.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__)
|
||||
Errors.E1045.format(
|
||||
parent="KnowledgeBase", method="from_bytes", name=self.__name__
|
||||
)
|
||||
)
|
||||
|
||||
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None:
|
||||
def to_disk(
|
||||
self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
|
||||
) -> None:
|
||||
"""
|
||||
Write KnowledgeBase content to disk.
|
||||
path (Union[str, Path]): Target file path.
|
||||
exclude (Iterable[str]): List of components to exclude.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
Errors.E1045.format(parent="KnowledgeBase", method="to_disk", name=self.__name__)
|
||||
Errors.E1045.format(
|
||||
parent="KnowledgeBase", method="to_disk", name=self.__name__
|
||||
)
|
||||
)
|
||||
|
||||
def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None:
def from_disk(
self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> None:
"""
Load KnowledgeBase content from disk.
path (Union[str, Path]): Target file path.
exclude (Iterable[str]): List of components to exclude.
"""
raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
Errors.E1045.format(
parent="KnowledgeBase", method="from_disk", name=self.__name__
)
)
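A rough round-trip sketch for the serialization hooks declared above, using the in-memory implementation (the directory name and vector length are placeholders, and the target directory is assumed to be creatable):

```python
# Hedged sketch: to_disk()/from_disk() round trip with InMemoryLookupKB.
import spacy
from spacy.kb import InMemoryLookupKB

nlp = spacy.blank("en")
kb = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q42", freq=342, entity_vector=[1, 9, -3])
kb.to_disk("./my_kb")

kb2 = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=3)
kb2.from_disk("./my_kb")
assert kb2.get_size_entities() == 1
```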
@ -55,23 +55,28 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
# optional data, we can let users configure a DB as the backend for this.
|
||||
cdef object _features_table
|
||||
|
||||
|
||||
cdef inline int64_t c_add_vector(self, vector[float] entity_vector) nogil:
|
||||
"""Add an entity vector to the vectors table."""
|
||||
cdef int64_t new_index = self._vectors_table.size()
|
||||
self._vectors_table.push_back(entity_vector)
|
||||
return new_index
|
||||
|
||||
|
||||
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq,
|
||||
int32_t vector_index, int feats_row) nogil:
|
||||
cdef inline int64_t c_add_entity(
|
||||
self,
|
||||
hash_t entity_hash,
|
||||
float freq,
|
||||
int32_t vector_index,
|
||||
int feats_row
|
||||
) nogil:
|
||||
"""Add an entry to the vector of entries.
|
||||
After calling this method, make sure to update also the _entry_index using the return value"""
|
||||
After calling this method, make sure to update also the _entry_index
|
||||
using the return value"""
|
||||
# This is what we'll map the entity hash key to. It's where the entry will sit
|
||||
# in the vector of entries, so we can get it later.
|
||||
cdef int64_t new_index = self._entries.size()
|
||||
|
||||
# Avoid struct initializer to enable nogil, cf https://github.com/cython/cython/issues/1642
|
||||
# Avoid struct initializer to enable nogil, cf.
|
||||
# https://github.com/cython/cython/issues/1642
|
||||
cdef KBEntryC entry
|
||||
entry.entity_hash = entity_hash
|
||||
entry.vector_index = vector_index
|
||||
|
@ -81,11 +86,17 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
self._entries.push_back(entry)
|
||||
return new_index
|
||||
|
||||
cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs) nogil:
|
||||
"""Connect a mention to a list of potential entities with their prior probabilities .
|
||||
After calling this method, make sure to update also the _alias_index using the return value"""
|
||||
# This is what we'll map the alias hash key to. It's where the alias will be defined
|
||||
# in the vector of aliases.
|
||||
cdef inline int64_t c_add_aliases(
|
||||
self,
|
||||
hash_t alias_hash,
|
||||
vector[int64_t] entry_indices,
|
||||
vector[float] probs
|
||||
) nogil:
|
||||
"""Connect a mention to a list of potential entities with their prior
|
||||
probabilities. After calling this method, make sure to update also the
|
||||
_alias_index using the return value"""
|
||||
# This is what we'll map the alias hash key to. It's where the alias will be
|
||||
# defined in the vector of aliases.
|
||||
cdef int64_t new_index = self._aliases_table.size()
|
||||
|
||||
# Avoid struct initializer to enable nogil
|
||||
|
@ -98,8 +109,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
|
||||
cdef inline void _create_empty_vectors(self, hash_t dummy_hash) nogil:
|
||||
"""
|
||||
Initializing the vectors and making sure the first element of each vector is a dummy,
|
||||
because the PreshMap maps pointing to indices in these vectors can not contain 0 as value
|
||||
Initializing the vectors and making sure the first element of each vector is a
|
||||
dummy, because the PreshMap maps pointing to indices in these vectors can not
|
||||
contain 0 as value.
|
||||
cf. https://github.com/explosion/preshed/issues/17
|
||||
"""
|
||||
cdef int32_t dummy_value = 0
|
||||
|
@ -130,12 +142,18 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
cdef class Writer:
|
||||
cdef FILE* _fp
|
||||
|
||||
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
|
||||
cdef int write_header(
|
||||
self, int64_t nr_entries, int64_t entity_vector_length
|
||||
) except -1
|
||||
cdef int write_vector_element(self, float element) except -1
|
||||
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1
|
||||
cdef int write_entry(
|
||||
self, hash_t entry_hash, float entry_freq, int32_t vector_index
|
||||
) except -1
|
||||
|
||||
cdef int write_alias_length(self, int64_t alias_length) except -1
|
||||
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
|
||||
cdef int write_alias_header(
|
||||
self, hash_t alias_hash, int64_t candidate_length
|
||||
) except -1
|
||||
cdef int write_alias(self, int64_t entry_index, float prob) except -1
|
||||
|
||||
cdef int _write(self, void* value, size_t size) except -1
|
||||
|
@ -143,12 +161,18 @@ cdef class Writer:
|
|||
cdef class Reader:
|
||||
cdef FILE* _fp
|
||||
|
||||
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
|
||||
cdef int read_header(
|
||||
self, int64_t* nr_entries, int64_t* entity_vector_length
|
||||
) except -1
|
||||
cdef int read_vector_element(self, float* element) except -1
|
||||
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1
|
||||
cdef int read_entry(
|
||||
self, hash_t* entity_hash, float* freq, int32_t* vector_index
|
||||
) except -1
|
||||
|
||||
cdef int read_alias_length(self, int64_t* alias_length) except -1
|
||||
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1
|
||||
cdef int read_alias_header(
|
||||
self, hash_t* alias_hash, int64_t* candidate_length
|
||||
) except -1
|
||||
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1
|
||||
|
||||
cdef int _read(self, void* value, size_t size) except -1
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
from typing import Any, Callable, Dict, Iterable, Union
|
||||
# cython: infer_types=True
|
||||
from typing import Any, Callable, Dict, Iterable
|
||||
|
||||
import srsly
|
||||
|
||||
|
@ -27,8 +27,9 @@ from .candidate import Candidate as Candidate
|
|||
|
||||
|
||||
cdef class InMemoryLookupKB(KnowledgeBase):
|
||||
"""An `InMemoryLookupKB` instance stores unique identifiers for entities and their textual aliases,
|
||||
to support entity linking of named entities to real-world concepts.
|
||||
"""An `InMemoryLookupKB` instance stores unique identifiers for entities
|
||||
and their textual aliases, to support entity linking of named entities to
|
||||
real-world concepts.
|
||||
|
||||
DOCS: https://spacy.io/api/inmemorylookupkb
|
||||
"""
|
||||
|
@ -71,7 +72,8 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
|
||||
def add_entity(self, str entity, float freq, vector[float] entity_vector):
|
||||
"""
|
||||
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
|
||||
Add an entity to the KB, optionally specifying its log probability
|
||||
based on corpus frequency.
|
||||
Return the hash of the entity ID/name at the end.
|
||||
"""
|
||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||
|
@ -83,14 +85,20 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
|
||||
# Raise an error if the provided entity vector is not of the correct length
|
||||
if len(entity_vector) != self.entity_vector_length:
|
||||
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||
raise ValueError(
|
||||
Errors.E141.format(
|
||||
found=len(entity_vector), required=self.entity_vector_length
|
||||
)
|
||||
)
|
||||
|
||||
vector_index = self.c_add_vector(entity_vector=entity_vector)
|
||||
|
||||
new_index = self.c_add_entity(entity_hash=entity_hash,
|
||||
freq=freq,
|
||||
vector_index=vector_index,
|
||||
feats_row=-1) # Features table currently not implemented
|
||||
new_index = self.c_add_entity(
|
||||
entity_hash=entity_hash,
|
||||
freq=freq,
|
||||
vector_index=vector_index,
|
||||
feats_row=-1
|
||||
) # Features table currently not implemented
|
||||
self._entry_index[entity_hash] = new_index
|
||||
|
||||
return entity_hash
|
||||
|
@ -115,7 +123,12 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
else:
|
||||
entity_vector = vector_list[i]
|
||||
if len(entity_vector) != self.entity_vector_length:
|
||||
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||
raise ValueError(
|
||||
Errors.E141.format(
|
||||
found=len(entity_vector),
|
||||
required=self.entity_vector_length
|
||||
)
|
||||
)
|
||||
|
||||
entry.entity_hash = entity_hash
|
||||
entry.freq = freq_list[i]
|
||||
|
@ -149,11 +162,15 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
previous_alias_nr = self.get_size_aliases()
|
||||
# Throw an error if the length of entities and probabilities are not the same
|
||||
if not len(entities) == len(probabilities):
|
||||
raise ValueError(Errors.E132.format(alias=alias,
|
||||
entities_length=len(entities),
|
||||
probabilities_length=len(probabilities)))
|
||||
raise ValueError(
|
||||
Errors.E132.format(
|
||||
alias=alias,
|
||||
entities_length=len(entities),
|
||||
probabilities_length=len(probabilities))
|
||||
)
|
||||
|
||||
# Throw an error if the probabilities sum up to more than 1 (allow for some rounding errors)
|
||||
# Throw an error if the probabilities sum up to more than 1 (allow for
|
||||
# some rounding errors)
|
||||
prob_sum = sum(probabilities)
|
||||
if prob_sum > 1.00001:
|
||||
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
|
||||
|
@@ -170,40 +187,47 @@ cdef class InMemoryLookupKB(KnowledgeBase):

for entity, prob in zip(entities, probabilities):
entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index:
if entity_hash not in self._entry_index:
raise ValueError(Errors.E134.format(entity=entity))

entry_index = <int64_t>self._entry_index.get(entity_hash)
entry_indices.push_back(int(entry_index))
probs.push_back(float(prob))

new_index = self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs)
new_index = self.c_add_aliases(
alias_hash=alias_hash, entry_indices=entry_indices, probs=probs
)
self._alias_index[alias_hash] = new_index

if previous_alias_nr + 1 != self.get_size_aliases():
raise RuntimeError(Errors.E891.format(alias=alias))
return alias_hash
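For context on the checks in `add_alias()` above (every entity must already be in the KB, and the prior probabilities may not sum to more than 1), a small hedged sketch; the entity IDs and numbers are invented:

```python
# Hedged sketch of the add_alias() constraints enforced above.
import spacy
from spacy.kb import InMemoryLookupKB

nlp = spacy.blank("en")
kb = InMemoryLookupKB(vocab=nlp.vocab, entity_vector_length=1)
kb.add_entity(entity="Q1", freq=10, entity_vector=[0.0])
kb.add_entity(entity="Q2", freq=5, entity_vector=[1.0])

kb.add_alias(alias="apple", entities=["Q1", "Q2"], probabilities=[0.7, 0.2])
print(kb.get_prior_prob("Q1", "apple"))  # 0.7

try:
    # 0.8 + 0.5 > 1, so this should be rejected (E133).
    kb.add_alias(alias="pear", entities=["Q1", "Q2"], probabilities=[0.8, 0.5])
except ValueError as err:
    print("rejected:", err)
```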
def append_alias(self, str alias, str entity, float prior_prob, ignore_warnings=False):
|
||||
def append_alias(
|
||||
self, str alias, str entity, float prior_prob, ignore_warnings=False
|
||||
):
|
||||
"""
|
||||
For an alias already existing in the KB, extend its potential entities with one more.
|
||||
For an alias already existing in the KB, extend its potential entities
|
||||
with one more.
|
||||
Throw a warning if either the alias or the entity is unknown,
|
||||
or when the combination is already previously recorded.
|
||||
Throw an error if this entity+prior prob would exceed the sum of 1.
|
||||
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
|
||||
For efficiency, it's best to use the method `add_alias` as much as
|
||||
possible instead of this one.
|
||||
"""
|
||||
# Check if the alias exists in the KB
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
if alias_hash not in self._alias_index:
|
||||
raise ValueError(Errors.E176.format(alias=alias))
|
||||
|
||||
# Check if the entity exists in the KB
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
if not entity_hash in self._entry_index:
|
||||
if entity_hash not in self._entry_index:
|
||||
raise ValueError(Errors.E134.format(entity=entity))
|
||||
entry_index = <int64_t>self._entry_index.get(entity_hash)
|
||||
|
||||
# Throw an error if the prior probabilities (including the new one) sum up to more than 1
|
||||
# Throw an error if the prior probabilities (including the new one)
|
||||
# sum up to more than 1
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
current_sum = sum([p for p in alias_entry.probs])
|
||||
|
@ -236,12 +260,13 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
|
||||
def get_alias_candidates(self, str alias) -> Iterable[Candidate]:
|
||||
"""
|
||||
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
Return candidate entities for an alias. Each candidate defines the
|
||||
entity, the original alias, and the prior probability of that alias
|
||||
resolving to that entity.
|
||||
If the alias is not known in the KB, and empty list is returned.
|
||||
"""
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
if alias_hash not in self._alias_index:
|
||||
return []
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
|
@ -249,10 +274,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
return [Candidate(kb=self,
|
||||
entity_hash=self._entries[entry_index].entity_hash,
|
||||
entity_freq=self._entries[entry_index].freq,
|
||||
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
||||
entity_vector=self._vectors_table[
|
||||
self._entries[entry_index].vector_index
|
||||
],
|
||||
alias_hash=alias_hash,
|
||||
prior_prob=prior_prob)
|
||||
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||
for (entry_index, prior_prob) in zip(
|
||||
alias_entry.entry_indices, alias_entry.probs
|
||||
)
|
||||
if entry_index != 0]
|
||||
|
||||
def get_vector(self, str entity):
|
||||
|
@ -266,8 +295,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
return self._vectors_table[self._entries[entry_index].vector_index]
|
||||
|
||||
def get_prior_prob(self, str entity, str alias):
|
||||
""" Return the prior probability of a given alias being linked to a given entity,
|
||||
or return 0.0 when this combination is not known in the knowledge base"""
|
||||
""" Return the prior probability of a given alias being linked to a
|
||||
given entity, or return 0.0 when this combination is not known in the
|
||||
knowledge base."""
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
|
||||
|
@ -278,7 +308,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
entry_index = self._entry_index[entity_hash]
|
||||
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
|
||||
for (entry_index, prior_prob) in zip(
|
||||
alias_entry.entry_indices, alias_entry.probs
|
||||
):
|
||||
if self._entries[entry_index].entity_hash == entity_hash:
|
||||
return prior_prob
|
||||
|
||||
|
@ -288,13 +320,19 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
"""Serialize the current state to a binary string.
|
||||
"""
|
||||
def serialize_header():
|
||||
header = (self.get_size_entities(), self.get_size_aliases(), self.entity_vector_length)
|
||||
header = (
|
||||
self.get_size_entities(),
|
||||
self.get_size_aliases(),
|
||||
self.entity_vector_length
|
||||
)
|
||||
return srsly.json_dumps(header)
|
||||
|
||||
def serialize_entries():
|
||||
i = 1
|
||||
tuples = []
|
||||
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
|
||||
for entry_hash, entry_index in sorted(
|
||||
self._entry_index.items(), key=lambda x: x[1]
|
||||
):
|
||||
entry = self._entries[entry_index]
|
||||
assert entry.entity_hash == entry_hash
|
||||
assert entry_index == i
|
||||
|
@ -307,7 +345,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
headers = []
|
||||
indices_lists = []
|
||||
probs_lists = []
|
||||
for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]):
|
||||
for alias_hash, alias_index in sorted(
|
||||
self._alias_index.items(), key=lambda x: x[1]
|
||||
):
|
||||
alias = self._aliases_table[alias_index]
|
||||
assert alias_index == i
|
||||
candidate_length = len(alias.entry_indices)
|
||||
|
@ -365,7 +405,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
indices = srsly.json_loads(all_data[1])
|
||||
probs = srsly.json_loads(all_data[2])
|
||||
for header, indices, probs in zip(headers, indices, probs):
|
||||
alias_hash, candidate_length = header
|
||||
alias_hash, _candidate_length = header
|
||||
alias.entry_indices = indices
|
||||
alias.probs = probs
|
||||
self._aliases_table[i] = alias
|
||||
|
@ -414,10 +454,14 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
writer.write_vector_element(element)
|
||||
i = i+1
|
||||
|
||||
# dumping the entry records in the order in which they are in the _entries vector.
|
||||
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
|
||||
# dumping the entry records in the order in which they are in the
|
||||
# _entries vector.
|
||||
# index 0 is a dummy object not stored in the _entry_index and can
|
||||
# be ignored.
|
||||
i = 1
|
||||
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
|
||||
for entry_hash, entry_index in sorted(
|
||||
self._entry_index.items(), key=lambda x: x[1]
|
||||
):
|
||||
entry = self._entries[entry_index]
|
||||
assert entry.entity_hash == entry_hash
|
||||
assert entry_index == i
|
||||
|
@ -429,7 +473,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
# dumping the aliases in the order in which they are in the _alias_index vector.
|
||||
# index 0 is a dummy object not stored in the _aliases_table and can be ignored.
|
||||
i = 1
|
||||
for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]):
|
||||
for alias_hash, alias_index in sorted(
|
||||
self._alias_index.items(), key=lambda x: x[1]
|
||||
):
|
||||
alias = self._aliases_table[alias_index]
|
||||
assert alias_index == i
|
||||
|
||||
|
@ -535,7 +581,8 @@ cdef class Writer:
|
|||
def __init__(self, path):
|
||||
assert isinstance(path, Path)
|
||||
content = bytes(path)
|
||||
cdef bytes bytes_loc = content.encode('utf8') if type(content) == str else content
|
||||
cdef bytes bytes_loc = content.encode('utf8') \
|
||||
if type(content) == str else content
|
||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||
if not self._fp:
|
||||
raise IOError(Errors.E146.format(path=path))
|
||||
|
@ -545,14 +592,18 @@ cdef class Writer:
|
|||
cdef size_t status = fclose(self._fp)
|
||||
assert status == 0
|
||||
|
||||
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1:
|
||||
cdef int write_header(
|
||||
self, int64_t nr_entries, int64_t entity_vector_length
|
||||
) except -1:
|
||||
self._write(&nr_entries, sizeof(nr_entries))
|
||||
self._write(&entity_vector_length, sizeof(entity_vector_length))
|
||||
|
||||
cdef int write_vector_element(self, float element) except -1:
|
||||
self._write(&element, sizeof(element))
|
||||
|
||||
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1:
|
||||
cdef int write_entry(
|
||||
self, hash_t entry_hash, float entry_freq, int32_t vector_index
|
||||
) except -1:
|
||||
self._write(&entry_hash, sizeof(entry_hash))
|
||||
self._write(&entry_freq, sizeof(entry_freq))
|
||||
self._write(&vector_index, sizeof(vector_index))
|
||||
|
@ -561,7 +612,9 @@ cdef class Writer:
|
|||
cdef int write_alias_length(self, int64_t alias_length) except -1:
|
||||
self._write(&alias_length, sizeof(alias_length))
|
||||
|
||||
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1:
|
||||
cdef int write_alias_header(
|
||||
self, hash_t alias_hash, int64_t candidate_length
|
||||
) except -1:
|
||||
self._write(&alias_hash, sizeof(alias_hash))
|
||||
self._write(&candidate_length, sizeof(candidate_length))
|
||||
|
||||
|
@ -577,16 +630,19 @@ cdef class Writer:
|
|||
cdef class Reader:
|
||||
def __init__(self, path):
|
||||
content = bytes(path)
|
||||
cdef bytes bytes_loc = content.encode('utf8') if type(content) == str else content
|
||||
cdef bytes bytes_loc = content.encode('utf8') \
|
||||
if type(content) == str else content
|
||||
self._fp = fopen(<char*>bytes_loc, 'rb')
|
||||
if not self._fp:
|
||||
PyErr_SetFromErrno(IOError)
|
||||
status = fseek(self._fp, 0, 0) # this can be 0 if there is no header
|
||||
fseek(self._fp, 0, 0) # this can be 0 if there is no header
|
||||
|
||||
def __dealloc__(self):
|
||||
fclose(self._fp)
|
||||
|
||||
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1:
|
||||
cdef int read_header(
|
||||
self, int64_t* nr_entries, int64_t* entity_vector_length
|
||||
) except -1:
|
||||
status = self._read(nr_entries, sizeof(int64_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
|
@ -606,7 +662,9 @@ cdef class Reader:
|
|||
return 0 # end of file
|
||||
raise IOError(Errors.E145.format(param="vector element"))
|
||||
|
||||
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1:
|
||||
cdef int read_entry(
|
||||
self, hash_t* entity_hash, float* freq, int32_t* vector_index
|
||||
) except -1:
|
||||
status = self._read(entity_hash, sizeof(hash_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
|
@ -637,7 +695,9 @@ cdef class Reader:
|
|||
return 0 # end of file
|
||||
raise IOError(Errors.E145.format(param="alias length"))
|
||||
|
||||
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
|
||||
cdef int read_alias_header(
|
||||
self, hash_t* alias_hash, int64_t* candidate_length
|
||||
) except -1:
|
||||
status = self._read(alias_hash, sizeof(hash_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
|
|
|
@ -163,7 +163,7 @@ class SpanishLemmatizer(Lemmatizer):
|
|||
for old, new in self.lookups.get_table("lemma_rules").get("det", []):
|
||||
if word == old:
|
||||
return [new]
|
||||
# If none of the specfic rules apply, search in the common rules for
|
||||
# If none of the specific rules apply, search in the common rules for
|
||||
# determiners and pronouns that follow a unique pattern for
|
||||
# lemmatization. If the word is in the list, return the corresponding
|
||||
# lemma.
|
||||
|
@ -291,7 +291,7 @@ class SpanishLemmatizer(Lemmatizer):
|
|||
for old, new in self.lookups.get_table("lemma_rules").get("pron", []):
|
||||
if word == old:
|
||||
return [new]
|
||||
# If none of the specfic rules apply, search in the common rules for
|
||||
# If none of the specific rules apply, search in the common rules for
|
||||
# determiners and pronouns that follow a unique pattern for
|
||||
# lemmatization. If the word is in the list, return the corresponding
|
||||
# lemma.
|
||||
|
|
|
@ -15,6 +15,7 @@ _prefixes = (
|
|||
[
|
||||
"†",
|
||||
"⸏",
|
||||
"〈",
|
||||
]
|
||||
+ LIST_PUNCT
|
||||
+ LIST_ELLIPSES
|
||||
|
@ -31,6 +32,7 @@ _suffixes = (
|
|||
+ [
|
||||
"†",
|
||||
"⸎",
|
||||
"〉",
|
||||
r"(?<=[\u1F00-\u1FFF\u0370-\u03FF])[\-\.⸏]",
|
||||
]
|
||||
)
|
||||
|
|
|
@ -15,4 +15,7 @@ sentences = [
|
|||
"Türkiye'nin başkenti neresi?",
|
||||
"Bakanlar Kurulu 180 günlük eylem planını açıkladı.",
|
||||
"Merkez Bankası, beklentiler doğrultusunda faizlerde değişikliğe gitmedi.",
|
||||
"Cemal Sureya kimdir?",
|
||||
"Bunlari Biliyor muydunuz?",
|
||||
"Altinoluk Turkiye haritasinin neresinde yer alir?",
|
||||
]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
|
@ -64,6 +65,7 @@ from .util import (
|
|||
registry,
|
||||
warn_if_jupyter_cupy,
|
||||
)
|
||||
from .vectors import BaseVectors
|
||||
from .vocab import Vocab, create_vocab
|
||||
|
||||
PipeCallable = Callable[[Doc], Doc]
|
||||
|
@ -157,6 +159,7 @@ class Language:
|
|||
max_length: int = 10**6,
|
||||
meta: Dict[str, Any] = {},
|
||||
create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
|
||||
create_vectors: Optional[Callable[["Vocab"], BaseVectors]] = None,
|
||||
batch_size: int = 1000,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
@ -197,6 +200,10 @@ class Language:
|
|||
if vocab is True:
|
||||
vectors_name = meta.get("vectors", {}).get("name")
|
||||
vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
|
||||
if not create_vectors:
|
||||
vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
|
||||
create_vectors = registry.resolve(vectors_cfg)["vectors"]
|
||||
vocab.vectors = create_vectors(vocab)
|
||||
else:
|
||||
if (self.lang and vocab.lang) and (self.lang != vocab.lang):
|
||||
raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
|
||||
|
@ -1764,6 +1771,10 @@ class Language:
|
|||
).merge(config)
|
||||
if "nlp" not in config:
|
||||
raise ValueError(Errors.E985.format(config=config))
|
||||
# fill in [nlp.vectors] if not present (as a narrower alternative to
|
||||
# auto-filling [nlp] from the default config)
|
||||
if "vectors" not in config["nlp"]:
|
||||
config["nlp"]["vectors"] = {"@vectors": "spacy.Vectors.v1"}
|
||||
config_lang = config["nlp"].get("lang")
|
||||
if config_lang is not None and config_lang != cls.lang:
|
||||
raise ValueError(
|
||||
|
@ -1795,6 +1806,7 @@ class Language:
|
|||
filled["nlp"], validate=validate, schema=ConfigSchemaNlp
|
||||
)
|
||||
create_tokenizer = resolved_nlp["tokenizer"]
|
||||
create_vectors = resolved_nlp["vectors"]
|
||||
before_creation = resolved_nlp["before_creation"]
|
||||
after_creation = resolved_nlp["after_creation"]
|
||||
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
|
||||
|
@ -1815,7 +1827,12 @@ class Language:
|
|||
# inside stuff like the spacy train function. If we loaded them here,
|
||||
# then we would load them twice at runtime: once when we make from config,
|
||||
# and then again when we load from disk.
|
||||
nlp = lang_cls(vocab=vocab, create_tokenizer=create_tokenizer, meta=meta)
|
||||
nlp = lang_cls(
|
||||
vocab=vocab,
|
||||
create_tokenizer=create_tokenizer,
|
||||
create_vectors=create_vectors,
|
||||
meta=meta,
|
||||
)
|
||||
if after_creation is not None:
|
||||
nlp = after_creation(nlp)
|
||||
if not isinstance(nlp, cls):
|
||||
|
@ -1825,7 +1842,6 @@ class Language:
|
|||
# Later we replace the component config with the raw config again.
|
||||
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
||||
pipeline = interpolated.get("components", {})
|
||||
sourced = util.get_sourced_components(interpolated)
|
||||
# If components are loaded from a source (existing models), we cache
|
||||
# them here so they're only loaded once
|
||||
source_nlps = {}
|
||||
|
@ -1958,7 +1974,7 @@ class Language:
|
|||
useful when training a pipeline with components sourced from an existing
|
||||
pipeline: if multiple components (e.g. tagger, parser, NER) listen to
|
||||
the same tok2vec component, but some of them are frozen and not updated,
|
||||
their performance may degrade significally as the tok2vec component is
|
||||
their performance may degrade significantly as the tok2vec component is
|
||||
updated with new data. To prevent this, listeners can be replaced with
|
||||
a standalone tok2vec layer that is owned by the component and doesn't
|
||||
change if the component isn't updated.
|
||||
|
@@ -2033,8 +2049,20 @@ class Language:
# Go over the listener layers and replace them
for listener in pipe_listeners:
new_model = tok2vec_model.copy()
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
replace_listener_func = tok2vec_model.attrs.get("replace_listener")
if replace_listener_func is not None:
# Pass the extra args to the callback without breaking compatibility with
# old library versions that only expect a single parameter.
num_params = len(
inspect.signature(replace_listener_func).parameters
)
if num_params == 1:
new_model = replace_listener_func(new_model)
elif num_params == 3:
new_model = replace_listener_func(new_model, listener, tok2vec)
else:
raise ValueError(Errors.E1055.format(num_params=num_params))

util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)
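The arity check above keeps the old one-argument `replace_listener` callbacks working while allowing the new three-argument form. A hedged sketch of what a three-parameter callback might look like on a custom tok2vec architecture; the attribute name comes from the code above, the callback body is illustrative only:

```python
# Hedged sketch: a three-parameter "replace_listener" callback. It receives the
# copied tok2vec model, the listener node being replaced, and the tok2vec pipe.
from thinc.api import Model


def replace_listener(copied_tok2vec: Model, listener: Model, tok2vec_pipe) -> Model:
    # Illustrative: record which listener was swapped out before returning the copy.
    copied_tok2vec.attrs["replaced_listener_name"] = listener.name
    return copied_tok2vec


def attach_hook(tok2vec_model: Model) -> Model:
    # A model architecture would normally register the hook when building the model.
    tok2vec_model.attrs["replace_listener"] = replace_listener
    return tok2vec_model
```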
@ -1,7 +1,7 @@
|
|||
# cython: embedsignature=True
|
||||
# cython: profile=False
|
||||
# Compiler crashes on memory view coercion without this. Should report bug.
|
||||
cimport numpy as np
|
||||
from cython.view cimport array as cvarray
|
||||
from libc.string cimport memset
|
||||
|
||||
np.import_array()
|
||||
|
@ -35,7 +35,7 @@ from .typedefs cimport attr_t, flags_t
|
|||
from .attrs import intify_attrs
|
||||
from .errors import Errors, Warnings
|
||||
|
||||
OOV_RANK = 0xffffffffffffffff # UINT64_MAX
|
||||
OOV_RANK = 0xffffffffffffffff # UINT64_MAX
|
||||
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
|
||||
EMPTY_LEXEME.id = OOV_RANK
|
||||
|
||||
|
@ -105,7 +105,7 @@ cdef class Lexeme:
|
|||
if isinstance(value, float):
|
||||
continue
|
||||
elif isinstance(value, (int, long)):
|
||||
Lexeme.set_struct_attr(self.c, attr, value)
|
||||
Lexeme.set_struct_attr(self.c, attr, value)
|
||||
else:
|
||||
Lexeme.set_struct_attr(self.c, attr, self.vocab.strings.add(value))
|
||||
|
||||
|
@ -137,10 +137,12 @@ cdef class Lexeme:
|
|||
if hasattr(other, "orth"):
|
||||
if self.c.orth == other.orth:
|
||||
return 1.0
|
||||
elif hasattr(other, "__len__") and len(other) == 1 \
|
||||
and hasattr(other[0], "orth"):
|
||||
if self.c.orth == other[0].orth:
|
||||
return 1.0
|
||||
elif (
|
||||
hasattr(other, "__len__") and len(other) == 1
|
||||
and hasattr(other[0], "orth")
|
||||
and self.c.orth == other[0].orth
|
||||
):
|
||||
return 1.0
|
||||
if self.vector_norm == 0 or other.vector_norm == 0:
|
||||
warnings.warn(Warnings.W008.format(obj="Lexeme"))
|
||||
return 0.0
|
||||
|
@ -149,7 +151,7 @@ cdef class Lexeme:
|
|||
result = xp.dot(vector, other.vector) / (self.vector_norm * other.vector_norm)
|
||||
# ensure we get a scalar back (numpy does this automatically but cupy doesn't)
|
||||
return result.item()
|
||||
|
||||
|
||||
@property
|
||||
def has_vector(self):
|
||||
"""RETURNS (bool): Whether a word vector is associated with the object.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
# cython: infer_types=True
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from itertools import product
|
||||
|
@ -108,7 +108,7 @@ cdef class DependencyMatcher:
|
|||
key (str): The match ID.
|
||||
RETURNS (bool): Whether the matcher contains rules for this match ID.
|
||||
"""
|
||||
return self.has_key(key)
|
||||
return self.has_key(key) # no-cython-lint: W601
|
||||
|
||||
def _validate_input(self, pattern, key):
|
||||
idx = 0
|
||||
|
@ -129,6 +129,7 @@ cdef class DependencyMatcher:
|
|||
else:
|
||||
required_keys = {"RIGHT_ID", "RIGHT_ATTRS", "REL_OP", "LEFT_ID"}
|
||||
relation_keys = set(relation.keys())
|
||||
# Identify required keys that have not been specified
|
||||
missing = required_keys - relation_keys
|
||||
if missing:
|
||||
missing_txt = ", ".join(list(missing))
|
||||
|
@@ -136,6 +137,13 @@ cdef class DependencyMatcher:
required=required_keys,
missing=missing_txt
))
# Identify additional, unsupported keys
unsupported = relation_keys - required_keys
if unsupported:
unsupported_txt = ", ".join(list(unsupported))
warnings.warn(Warnings.W126.format(
unsupported=unsupported_txt
))
if (
relation["RIGHT_ID"] in visited_nodes
or relation["LEFT_ID"] not in visited_nodes
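The validation above requires exactly `RIGHT_ID`, `RIGHT_ATTRS`, `REL_OP` and `LEFT_ID` on each non-anchor relation and now emits W126 for anything else. A short example of a well-formed pattern (the pipeline name is assumed to be installed; texts and attributes are arbitrary):

```python
# Hedged sketch: a DependencyMatcher pattern using only the supported keys.
# Any extra key on the second dict would now trigger the W126 warning.
import spacy
from spacy.matcher import DependencyMatcher

nlp = spacy.load("en_core_web_sm")  # assumed to be installed
matcher = DependencyMatcher(nlp.vocab)
pattern = [
    {"RIGHT_ID": "verb", "RIGHT_ATTRS": {"POS": "VERB"}},
    {
        "LEFT_ID": "verb",
        "REL_OP": ">",
        "RIGHT_ID": "subject",
        "RIGHT_ATTRS": {"DEP": "nsubj"},
    },
]
matcher.add("VERB_SUBJECT", [pattern])
doc = nlp("Smith founded a healthcare company.")
print(matcher(doc))
```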
@ -264,7 +272,7 @@ cdef class DependencyMatcher:
|
|||
|
||||
def remove(self, key):
|
||||
key = self._normalize_key(key)
|
||||
if not key in self._patterns:
|
||||
if key not in self._patterns:
|
||||
raise ValueError(Errors.E175.format(key=key))
|
||||
self._patterns.pop(key)
|
||||
self._raw_patterns.pop(key)
|
||||
|
@ -382,7 +390,7 @@ cdef class DependencyMatcher:
|
|||
return []
|
||||
return [doc[node].head]
|
||||
|
||||
def _gov(self,doc,node):
|
||||
def _gov(self, doc, node):
|
||||
return list(doc[node].children)
|
||||
|
||||
def _dep_chain(self, doc, node):
|
||||
|
@ -443,7 +451,7 @@ cdef class DependencyMatcher:
|
|||
|
||||
def _right_child(self, doc, node):
|
||||
return [child for child in doc[node].rights]
|
||||
|
||||
|
||||
def _left_child(self, doc, node):
|
||||
return [child for child in doc[node].lefts]
|
||||
|
||||
|
@ -461,7 +469,7 @@ cdef class DependencyMatcher:
|
|||
if doc[node].head.i > node:
|
||||
return [doc[node].head]
|
||||
return []
|
||||
|
||||
|
||||
def _left_parent(self, doc, node):
|
||||
if doc[node].head.i < node:
|
||||
return [doc[node].head]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: profile=True, binding=True, infer_types=True
|
||||
# cython: binding=True, infer_types=True
|
||||
from cpython.object cimport PyObject
|
||||
from libc.stdint cimport int64_t
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: binding=True, infer_types=True, profile=True
|
||||
# cython: binding=True, infer_types=True
|
||||
from typing import Iterable, List
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
|
@ -12,31 +12,18 @@ import warnings
|
|||
|
||||
import srsly
|
||||
|
||||
from ..attrs cimport (
|
||||
DEP,
|
||||
ENT_IOB,
|
||||
ID,
|
||||
LEMMA,
|
||||
MORPH,
|
||||
NULL_ATTR,
|
||||
ORTH,
|
||||
POS,
|
||||
TAG,
|
||||
attr_id_t,
|
||||
)
|
||||
from ..attrs cimport DEP, ENT_IOB, ID, LEMMA, MORPH, NULL_ATTR, POS, TAG
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.doc cimport Doc, get_token_attr_for_matcher
|
||||
from ..tokens.morphanalysis cimport MorphAnalysis
|
||||
from ..tokens.span cimport Span
|
||||
from ..tokens.token cimport Token
|
||||
from ..typedefs cimport attr_t
|
||||
from ..vocab cimport Vocab
|
||||
|
||||
from ..attrs import IDS
|
||||
from ..errors import Errors, MatchPatternError, Warnings
|
||||
from ..schemas import validate_token_pattern
|
||||
from ..strings import get_string_id
|
||||
from ..util import registry
|
||||
from .levenshtein import levenshtein_compare
|
||||
|
||||
DEF PADDING = 5
|
||||
|
@ -87,9 +74,9 @@ cdef class Matcher:
|
|||
key (str): The match ID.
|
||||
RETURNS (bool): Whether the matcher contains rules for this match ID.
|
||||
"""
|
||||
return self.has_key(key)
|
||||
return self.has_key(key) # no-cython-lint: W601
|
||||
|
||||
def add(self, key, patterns, *, on_match=None, greedy: str=None):
|
||||
def add(self, key, patterns, *, on_match=None, greedy: str = None):
|
||||
"""Add a match-rule to the matcher. A match-rule consists of: an ID
|
||||
key, an on_match callback, and one or more patterns.
|
||||
|
||||
|
@ -143,8 +130,13 @@ cdef class Matcher:
|
|||
key = self._normalize_key(key)
|
||||
for pattern in patterns:
|
||||
try:
|
||||
specs = _preprocess_pattern(pattern, self.vocab,
|
||||
self._extensions, self._extra_predicates, self._fuzzy_compare)
|
||||
specs = _preprocess_pattern(
|
||||
pattern,
|
||||
self.vocab,
|
||||
self._extensions,
|
||||
self._extra_predicates,
|
||||
self._fuzzy_compare
|
||||
)
|
||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||
for spec in specs:
|
||||
for attr, _ in spec[1]:
|
||||
|
@ -168,7 +160,7 @@ cdef class Matcher:
|
|||
key (str): The ID of the match rule.
|
||||
"""
|
||||
norm_key = self._normalize_key(key)
|
||||
if not norm_key in self._patterns:
|
||||
if norm_key not in self._patterns:
|
||||
raise ValueError(Errors.E175.format(key=key))
|
||||
self._patterns.pop(norm_key)
|
||||
self._callbacks.pop(norm_key)
|
||||
|
@ -268,8 +260,15 @@ cdef class Matcher:
|
|||
if self.patterns.empty():
|
||||
matches = []
|
||||
else:
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
|
||||
extensions=self._extensions, predicates=self._extra_predicates, with_alignments=with_alignments)
|
||||
matches = find_matches(
|
||||
&self.patterns[0],
|
||||
self.patterns.size(),
|
||||
doclike,
|
||||
length,
|
||||
extensions=self._extensions,
|
||||
predicates=self._extra_predicates,
|
||||
with_alignments=with_alignments
|
||||
)
|
||||
final_matches = []
|
||||
pairs_by_id = {}
|
||||
# For each key, either add all matches, or only the filtered,
|
||||
|
@@ -289,9 +288,9 @@ cdef class Matcher:
memset(matched, 0, length * sizeof(matched[0]))
span_filter = self._filter.get(key)
if span_filter == "FIRST":
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False)  # sort by start
elif span_filter == "LONGEST":
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True)  # reverse sort by length
else:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
for match in sorted_pairs:
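The FIRST/LONGEST sorting above is what backs the `greedy` option on `Matcher.add`. A brief sketch of the observable behaviour (token texts are arbitrary):

```python
# Hedged sketch: greedy="LONGEST" keeps only the longest of overlapping matches
# for a key, per the sorting and filtering above.
import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)
pattern = [{"LOWER": "very", "OP": "*"}, {"LOWER": "happy"}]
matcher.add("VERY_HAPPY", [pattern], greedy="LONGEST")

doc = nlp("She was very very happy about it.")
for match_id, start, end in matcher(doc):
    print(doc[start:end].text)  # expected: "very very happy" only
```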
@ -366,7 +365,6 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
|
|||
cdef vector[MatchC] matches
|
||||
cdef vector[vector[MatchAlignmentC]] align_states
|
||||
cdef vector[vector[MatchAlignmentC]] align_matches
|
||||
cdef PatternStateC state
|
||||
cdef int i, j, nr_extra_attr
|
||||
cdef Pool mem = Pool()
|
||||
output = []
|
||||
|
@ -388,14 +386,22 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
|
|||
value = token.vocab.strings[value]
|
||||
extra_attr_values[i * nr_extra_attr + index] = value
|
||||
# Main loop
|
||||
cdef int nr_predicate = len(predicates)
|
||||
for i in range(length):
|
||||
for j in range(n):
|
||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||
if with_alignments != 0:
|
||||
align_states.resize(states.size())
|
||||
transition_states(states, matches, align_states, align_matches, predicate_cache,
|
||||
doclike[i], extra_attr_values, predicates, with_alignments)
|
||||
transition_states(
|
||||
states,
|
||||
matches,
|
||||
align_states,
|
||||
align_matches,
|
||||
predicate_cache,
|
||||
doclike[i],
|
||||
extra_attr_values,
|
||||
predicates,
|
||||
with_alignments
|
||||
)
|
||||
extra_attr_values += nr_extra_attr
|
||||
predicate_cache += len(predicates)
|
||||
# Handle matches that end in 0-width patterns
|
||||
|
@ -421,18 +427,28 @@ cdef find_matches(TokenPatternC** patterns, int n, object doclike, int length, e
|
|||
return output
|
||||
|
||||
|
||||
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
|
||||
vector[vector[MatchAlignmentC]]& align_states, vector[vector[MatchAlignmentC]]& align_matches,
|
||||
int8_t* cached_py_predicates,
|
||||
Token token, const attr_t* extra_attrs, py_predicates, bint with_alignments) except *:
|
||||
cdef void transition_states(
|
||||
vector[PatternStateC]& states,
|
||||
vector[MatchC]& matches,
|
||||
vector[vector[MatchAlignmentC]]& align_states,
|
||||
vector[vector[MatchAlignmentC]]& align_matches,
|
||||
int8_t* cached_py_predicates,
|
||||
Token token,
|
||||
const attr_t* extra_attrs,
|
||||
py_predicates,
|
||||
bint with_alignments
|
||||
) except *:
|
||||
cdef int q = 0
|
||||
cdef vector[PatternStateC] new_states
|
||||
cdef vector[vector[MatchAlignmentC]] align_new_states
|
||||
cdef int nr_predicate = len(py_predicates)
|
||||
for i in range(states.size()):
|
||||
if states[i].pattern.nr_py >= 1:
|
||||
update_predicate_cache(cached_py_predicates,
|
||||
states[i].pattern, token, py_predicates)
|
||||
update_predicate_cache(
|
||||
cached_py_predicates,
|
||||
states[i].pattern,
|
||||
token,
|
||||
py_predicates
|
||||
)
|
||||
action = get_action(states[i], token.c, extra_attrs,
|
||||
cached_py_predicates)
|
||||
if action == REJECT:
|
||||
|
@ -468,8 +484,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
align_new_states.push_back(align_states[q])
|
||||
states[q].pattern += 1
|
||||
if states[q].pattern.nr_py != 0:
|
||||
update_predicate_cache(cached_py_predicates,
|
||||
states[q].pattern, token, py_predicates)
|
||||
update_predicate_cache(
|
||||
cached_py_predicates,
|
||||
states[q].pattern,
|
||||
token,
|
||||
py_predicates
|
||||
)
|
||||
action = get_action(states[q], token.c, extra_attrs,
|
||||
cached_py_predicates)
|
||||
# Update alignment before the transition of current state
|
||||
|
@ -485,8 +505,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
ent_id = get_ent_id(state.pattern)
|
||||
if action == MATCH:
|
||||
matches.push_back(
|
||||
MatchC(pattern_id=ent_id, start=state.start,
|
||||
length=state.length+1))
|
||||
MatchC(
|
||||
pattern_id=ent_id,
|
||||
start=state.start,
|
||||
length=state.length+1
|
||||
)
|
||||
)
|
||||
# `align_matches` always corresponds to `matches` 1:1
|
||||
if with_alignments != 0:
|
||||
align_matches.push_back(align_states[q])
|
||||
|
@ -494,23 +518,35 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
# push match without last token if length > 0
|
||||
if state.length > 0:
|
||||
matches.push_back(
|
||||
MatchC(pattern_id=ent_id, start=state.start,
|
||||
length=state.length))
|
||||
MatchC(
|
||||
pattern_id=ent_id,
|
||||
start=state.start,
|
||||
length=state.length
|
||||
)
|
||||
)
|
||||
# MATCH_DOUBLE emits matches twice,
|
||||
# add one more to align_matches in order to keep 1:1 relationship
|
||||
if with_alignments != 0:
|
||||
align_matches.push_back(align_states[q])
|
||||
# push match with last token
|
||||
matches.push_back(
|
||||
MatchC(pattern_id=ent_id, start=state.start,
|
||||
length=state.length+1))
|
||||
MatchC(
|
||||
pattern_id=ent_id,
|
||||
start=state.start,
|
||||
length=state.length + 1
|
||||
)
|
||||
)
|
||||
# `align_matches` always corresponds to `matches` 1:1
|
||||
if with_alignments != 0:
|
||||
align_matches.push_back(align_states[q])
|
||||
elif action == MATCH_REJECT:
|
||||
matches.push_back(
|
||||
MatchC(pattern_id=ent_id, start=state.start,
|
||||
length=state.length))
|
||||
MatchC(
|
||||
pattern_id=ent_id,
|
||||
start=state.start,
|
||||
length=state.length
|
||||
)
|
||||
)
|
||||
# `align_matches` always corresponds to `matches` 1:1
|
||||
if with_alignments != 0:
|
||||
align_matches.push_back(align_states[q])
|
||||
|
@ -533,8 +569,12 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
align_states.push_back(align_new_states[i])
|
||||
|
||||
|
||||
cdef int update_predicate_cache(int8_t* cache,
|
||||
const TokenPatternC* pattern, Token token, predicates) except -1:
|
||||
cdef int update_predicate_cache(
|
||||
int8_t* cache,
|
||||
const TokenPatternC* pattern,
|
||||
Token token,
|
||||
predicates
|
||||
) except -1:
|
||||
# If the state references any extra predicates, check whether they match.
|
||||
# These are cached, so that we don't call these potentially expensive
|
||||
# Python functions more than we need to.
|
||||
|
@ -580,10 +620,12 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states,
|
|||
else:
|
||||
state.pattern += 1
|
||||
|
||||
|
||||
cdef action_t get_action(PatternStateC state,
|
||||
const TokenC* token, const attr_t* extra_attrs,
|
||||
const int8_t* predicate_matches) nogil:
|
||||
cdef action_t get_action(
|
||||
PatternStateC state,
|
||||
const TokenC * token,
|
||||
const attr_t * extra_attrs,
|
||||
const int8_t * predicate_matches
|
||||
) nogil:
|
||||
"""We need to consider:
|
||||
a) Does the token match the specification? [Yes, No]
|
||||
b) What's the quantifier? [1, 0+, ?]
|
||||
|
@ -649,53 +691,56 @@ cdef action_t get_action(PatternStateC state,
|
|||
is_match = not is_match
|
||||
quantifier = ONE
|
||||
if quantifier == ONE:
|
||||
if is_match and is_final:
|
||||
# Yes, final: 1000
|
||||
return MATCH
|
||||
elif is_match and not is_final:
|
||||
# Yes, non-final: 0100
|
||||
return ADVANCE
|
||||
elif not is_match and is_final:
|
||||
# No, final: 0000
|
||||
return REJECT
|
||||
else:
|
||||
return REJECT
|
||||
if is_match and is_final:
|
||||
# Yes, final: 1000
|
||||
return MATCH
|
||||
elif is_match and not is_final:
|
||||
# Yes, non-final: 0100
|
||||
return ADVANCE
|
||||
elif not is_match and is_final:
|
||||
# No, final: 0000
|
||||
return REJECT
|
||||
else:
|
||||
return REJECT
|
||||
elif quantifier == ZERO_PLUS:
|
||||
if is_match and is_final:
|
||||
# Yes, final: 1001
|
||||
return MATCH_EXTEND
|
||||
elif is_match and not is_final:
|
||||
# Yes, non-final: 0011
|
||||
return RETRY_EXTEND
|
||||
elif not is_match and is_final:
|
||||
# No, final 2000 (note: Don't include last token!)
|
||||
return MATCH_REJECT
|
||||
else:
|
||||
# No, non-final 0010
|
||||
return RETRY
|
||||
if is_match and is_final:
|
||||
# Yes, final: 1001
|
||||
return MATCH_EXTEND
|
||||
elif is_match and not is_final:
|
||||
# Yes, non-final: 0011
|
||||
return RETRY_EXTEND
|
||||
elif not is_match and is_final:
|
||||
# No, final 2000 (note: Don't include last token!)
|
||||
return MATCH_REJECT
|
||||
else:
|
||||
# No, non-final 0010
|
||||
return RETRY
|
||||
elif quantifier == ZERO_ONE:
|
||||
if is_match and is_final:
|
||||
# Yes, final: 3000
|
||||
# To cater for a pattern ending in "?", we need to add
|
||||
# a match both with and without the last token
|
||||
return MATCH_DOUBLE
|
||||
elif is_match and not is_final:
|
||||
# Yes, non-final: 0110
|
||||
# We need both branches here, consider a pair like:
|
||||
# pattern: .?b string: b
|
||||
# If we 'ADVANCE' on the .?, we miss the match.
|
||||
return RETRY_ADVANCE
|
||||
elif not is_match and is_final:
|
||||
# No, final 2000 (note: Don't include last token!)
|
||||
return MATCH_REJECT
|
||||
else:
|
||||
# No, non-final 0010
|
||||
return RETRY
|
||||
if is_match and is_final:
|
||||
# Yes, final: 3000
|
||||
# To cater for a pattern ending in "?", we need to add
|
||||
# a match both with and without the last token
|
||||
return MATCH_DOUBLE
|
||||
elif is_match and not is_final:
|
||||
# Yes, non-final: 0110
|
||||
# We need both branches here, consider a pair like:
|
||||
# pattern: .?b string: b
|
||||
# If we 'ADVANCE' on the .?, we miss the match.
|
||||
return RETRY_ADVANCE
|
||||
elif not is_match and is_final:
|
||||
# No, final 2000 (note: Don't include last token!)
|
||||
return MATCH_REJECT
|
||||
else:
|
||||
# No, non-final 0010
|
||||
return RETRY
|
||||
|
||||
|
||||
cdef int8_t get_is_match(PatternStateC state,
|
||||
const TokenC* token, const attr_t* extra_attrs,
|
||||
const int8_t* predicate_matches) nogil:
|
||||
cdef int8_t get_is_match(
|
||||
PatternStateC state,
|
||||
const TokenC* token,
|
||||
const attr_t* extra_attrs,
|
||||
const int8_t* predicate_matches
|
||||
) nogil:
|
||||
for i in range(state.pattern.nr_py):
|
||||
if predicate_matches[state.pattern.py_predicates[i]] == -1:
|
||||
return 0
|
||||
|
@ -860,7 +905,7 @@ class _FuzzyPredicate:
|
|||
self.is_extension = is_extension
|
||||
if self.predicate not in self.operators:
|
||||
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
||||
fuzz = self.predicate[len("FUZZY"):] # number after prefix
|
||||
fuzz = self.predicate[len("FUZZY"):] # number after prefix
|
||||
self.fuzzy = int(fuzz) if fuzz else -1
|
||||
self.fuzzy_compare = fuzzy_compare
|
||||
self.key = _predicate_cache_key(self.attr, self.predicate, value, fuzzy=self.fuzzy)
|
||||
|
@ -1082,7 +1127,7 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
|
|||
elif cls == _FuzzyPredicate:
|
||||
if isinstance(value, dict):
|
||||
# add predicates inside fuzzy operator
|
||||
fuzz = type_[len("FUZZY"):] # number after prefix
|
||||
fuzz = type_[len("FUZZY"):] # number after prefix
|
||||
fuzzy_val = int(fuzz) if fuzz else -1
|
||||
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
|
||||
extra_predicates, seen_predicates,
|
||||
|
@ -1101,8 +1146,9 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
|
|||
return output
|
||||
|
||||
|
||||
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
|
||||
seen_predicates):
|
||||
def _get_extension_extra_predicates(
|
||||
spec, extra_predicates, predicate_types, seen_predicates
|
||||
):
|
||||
output = []
|
||||
for attr, value in spec.items():
|
||||
if isinstance(value, dict):
|
||||
|
@ -1131,7 +1177,7 @@ def _get_operators(spec):
|
|||
return (ONE,)
|
||||
elif spec["OP"] in lookup:
|
||||
return lookup[spec["OP"]]
|
||||
#Min_max {n,m}
|
||||
# Min_max {n,m}
|
||||
elif spec["OP"].startswith("{") and spec["OP"].endswith("}"):
|
||||
# {n} --> {n,n} exactly n ONE,(n)
|
||||
# {n,m}--> {n,m} min of n, max of m ONE,(n),ZERO_ONE,(m)
|
||||
|
@@ -1142,8 +1188,8 @@ def _get_operators(spec):
min_max = min_max if "," in min_max else f"{min_max},{min_max}"
n, m = min_max.split(",")

#1. Either n or m is a blank string and the other is numeric -->isdigit
#2. Both are numeric and n <= m
# 1. Either n or m is a blank string and the other is numeric -->isdigit
# 2. Both are numeric and n <= m
if (not n.isdecimal() and not m.isdecimal()) or (n.isdecimal() and m.isdecimal() and int(n) > int(m)):
keys = ", ".join(lookup.keys()) + ", {n}, {n,m}, {n,}, {,m} where n and m are integers and n <= m "
raise ValueError(Errors.E011.format(op=spec["OP"], opts=keys))
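The branch above parses the `{n,m}`-style values of the `OP` key. A short example of the accepted forms (token text is arbitrary):

```python
# Hedged sketch: min/max quantifiers on OP, validated by the parsing above.
# Accepted forms: {n}, {n,m}, {n,} and {,m}, with integer n <= m.
import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)
matcher.add("TWO_TO_THREE_LA", [[{"LOWER": "la", "OP": "{2,3}"}]])

doc = nlp("la la la la")
for match_id, start, end in matcher(doc):
    print(start, end, doc[start:end].text)
```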
@ -1,14 +1,12 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
from libc.stdint cimport uintptr_t
|
||||
# cython: infer_types=True
|
||||
from preshed.maps cimport map_clear, map_get, map_init, map_iter, map_set
|
||||
|
||||
import warnings
|
||||
|
||||
from ..attrs cimport DEP, LEMMA, MORPH, ORTH, POS, TAG
|
||||
from ..attrs cimport DEP, LEMMA, MORPH, POS, TAG
|
||||
|
||||
from ..attrs import IDS
|
||||
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.span cimport Span
|
||||
from ..tokens.token cimport Token
|
||||
from ..typedefs cimport attr_t
|
||||
|
|
|
@@ -40,11 +40,16 @@ cdef ActivationsC alloc_activations(SizesC n) nogil

cdef void free_activations(const ActivationsC* A) nogil

cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
                         const WeightsC* W, SizesC n) nogil

cdef void predict_states(
    CBlas cblas, ActivationsC* A, StateC** states, const WeightsC* W, SizesC n
) nogil

cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) nogil

cdef void cpu_log_loss(float* d_scores,
                       const float* costs, const int* is_valid, const float* scores, int O) nogil

cdef void cpu_log_loss(
    float* d_scores,
    const float* costs,
    const int* is_valid,
    const float* scores,
    int O
) nogil
@@ -1,4 +1,5 @@
# cython: infer_types=True, cdivision=True, boundscheck=False
# cython: profile=False
cimport numpy as np
from libc.math cimport exp
from libc.stdlib cimport calloc, free, realloc

@@ -8,13 +9,13 @@ from thinc.backends.linalg cimport Vec, VecVec

import numpy
import numpy.random
from thinc.api import CupyOps, Model, NumpyOps, get_ops
from thinc.api import CupyOps, Model, NumpyOps

from .. import util
from ..errors import Errors

from ..pipeline._parser_internals.stateclass cimport StateClass
from ..typedefs cimport class_t, hash_t, weight_t
from ..typedefs cimport weight_t


cdef WeightsC get_c_weights(model) except *:
@ -78,33 +79,48 @@ cdef void resize_activations(ActivationsC* A, SizesC n) nogil:
|
|||
A.is_valid = <int*>calloc(n.states * n.classes, sizeof(A.is_valid[0]))
|
||||
A._max_size = n.states
|
||||
else:
|
||||
A.token_ids = <int*>realloc(A.token_ids,
|
||||
n.states * n.feats * sizeof(A.token_ids[0]))
|
||||
A.scores = <float*>realloc(A.scores,
|
||||
n.states * n.classes * sizeof(A.scores[0]))
|
||||
A.unmaxed = <float*>realloc(A.unmaxed,
|
||||
n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0]))
|
||||
A.hiddens = <float*>realloc(A.hiddens,
|
||||
n.states * n.hiddens * sizeof(A.hiddens[0]))
|
||||
A.is_valid = <int*>realloc(A.is_valid,
|
||||
n.states * n.classes * sizeof(A.is_valid[0]))
|
||||
A.token_ids = <int*>realloc(
|
||||
A.token_ids, n.states * n.feats * sizeof(A.token_ids[0])
|
||||
)
|
||||
A.scores = <float*>realloc(
|
||||
A.scores, n.states * n.classes * sizeof(A.scores[0])
|
||||
)
|
||||
A.unmaxed = <float*>realloc(
|
||||
A.unmaxed, n.states * n.hiddens * n.pieces * sizeof(A.unmaxed[0])
|
||||
)
|
||||
A.hiddens = <float*>realloc(
|
||||
A.hiddens, n.states * n.hiddens * sizeof(A.hiddens[0])
|
||||
)
|
||||
A.is_valid = <int*>realloc(
|
||||
A.is_valid, n.states * n.classes * sizeof(A.is_valid[0])
|
||||
)
|
||||
A._max_size = n.states
|
||||
A._curr_size = n.states
|
||||
|
||||
|
||||
cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
||||
const WeightsC* W, SizesC n) nogil:
|
||||
cdef double one = 1.0
|
||||
cdef void predict_states(
|
||||
CBlas cblas, ActivationsC* A, StateC** states, const WeightsC* W, SizesC n
|
||||
) nogil:
|
||||
resize_activations(A, n)
|
||||
for i in range(n.states):
|
||||
states[i].set_context_tokens(&A.token_ids[i*n.feats], n.feats)
|
||||
memset(A.unmaxed, 0, n.states * n.hiddens * n.pieces * sizeof(float))
|
||||
memset(A.hiddens, 0, n.states * n.hiddens * sizeof(float))
|
||||
sum_state_features(cblas, A.unmaxed,
|
||||
W.feat_weights, A.token_ids, n.states, n.feats, n.hiddens * n.pieces)
|
||||
sum_state_features(
|
||||
cblas,
|
||||
A.unmaxed,
|
||||
W.feat_weights,
|
||||
A.token_ids,
|
||||
n.states,
|
||||
n.feats,
|
||||
n.hiddens * n.pieces
|
||||
)
|
||||
for i in range(n.states):
|
||||
VecVec.add_i(&A.unmaxed[i*n.hiddens*n.pieces],
|
||||
W.feat_bias, 1., n.hiddens * n.pieces)
|
||||
VecVec.add_i(
|
||||
&A.unmaxed[i*n.hiddens*n.pieces],
|
||||
W.feat_bias, 1.,
|
||||
n.hiddens * n.pieces
|
||||
)
|
||||
for j in range(n.hiddens):
|
||||
index = i * n.hiddens * n.pieces + j * n.pieces
|
||||
which = Vec.arg_max(&A.unmaxed[index], n.pieces)
|
||||
|
@ -114,14 +130,15 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
|||
memcpy(A.scores, A.hiddens, n.states * n.classes * sizeof(float))
|
||||
else:
|
||||
# Compute hidden-to-output
|
||||
sgemm(cblas)(False, True, n.states, n.classes, n.hiddens,
|
||||
sgemm(cblas)(
|
||||
False, True, n.states, n.classes, n.hiddens,
|
||||
1.0, <const float *>A.hiddens, n.hiddens,
|
||||
<const float *>W.hidden_weights, n.hiddens,
|
||||
0.0, A.scores, n.classes)
|
||||
0.0, A.scores, n.classes
|
||||
)
|
||||
# Add bias
|
||||
for i in range(n.states):
|
||||
VecVec.add_i(&A.scores[i*n.classes],
|
||||
W.hidden_bias, 1., n.classes)
|
||||
VecVec.add_i(&A.scores[i*n.classes], W.hidden_bias, 1., n.classes)
|
||||
# Set unseen classes to minimum value
|
||||
i = 0
|
||||
min_ = A.scores[0]
|
||||
|
@ -134,9 +151,16 @@ cdef void predict_states(CBlas cblas, ActivationsC* A, StateC** states,
|
|||
A.scores[i*n.classes+j] = min_
|
||||
|
||||
|
||||
cdef void sum_state_features(CBlas cblas, float* output,
|
||||
const float* cached, const int* token_ids, int B, int F, int O) nogil:
|
||||
cdef int idx, b, f, i
|
||||
cdef void sum_state_features(
|
||||
CBlas cblas,
|
||||
float* output,
|
||||
const float* cached,
|
||||
const int* token_ids,
|
||||
int B,
|
||||
int F,
|
||||
int O
|
||||
) nogil:
|
||||
cdef int idx, b, f
|
||||
cdef const float* feature
|
||||
padding = cached
|
||||
cached += F * O
|
||||
|
@@ -153,9 +177,13 @@ cdef void sum_state_features(CBlas cblas, float* output,
        token_ids += F


cdef void cpu_log_loss(float* d_scores,
                       const float* costs, const int* is_valid, const float* scores,
                       int O) nogil:
cdef void cpu_log_loss(
    float* d_scores,
    const float* costs,
    const int* is_valid,
    const float* scores,
    int O
) nogil:
    """Do multi-label log loss"""
    cdef double max_, gmax, Z, gZ
    best = arg_max_if_gold(scores, costs, is_valid, O)

@@ -179,8 +207,9 @@ cdef void cpu_log_loss(float* d_scores,
            d_scores[i] = exp(scores[i]-max_) / Z


cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs,
                         const int* is_valid, int n) nogil:
cdef int arg_max_if_gold(
    const weight_t* scores, const weight_t* costs, const int* is_valid, int n
) nogil:
    # Find minimum cost
    cdef float cost = 1
    for i in range(n):
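Note: `cpu_log_loss` computes a softmax-style gradient over the valid transition scores, pushing probability mass toward the lowest-cost (gold) actions. A rough NumPy sketch of that idea, an assumption based on the signature and docstring above rather than a line-for-line port of the C code:

import numpy as np

def cpu_log_loss_sketch(scores, costs, is_valid):
    scores = np.asarray(scores, dtype="f")
    costs = np.asarray(costs, dtype="f")
    mask = np.asarray(is_valid, dtype=bool)
    # gold actions: valid actions whose cost equals the minimum valid cost
    gold_mask = mask & (costs <= costs[mask].min())
    exp_scores = np.where(mask, np.exp(scores - scores[mask].max()), 0.0)
    Z = exp_scores.sum()
    gZ = exp_scores[gold_mask].sum()
    # gradient: softmax over valid actions minus softmax over gold actions
    return exp_scores / Z - np.where(gold_mask, exp_scores / gZ, 0.0)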
@ -204,10 +233,17 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
|
|||
return best
|
||||
|
||||
|
||||
|
||||
class ParserStepModel(Model):
|
||||
def __init__(self, docs, layers, *, has_upper, unseen_classes=None, train=True,
|
||||
dropout=0.1):
|
||||
def __init__(
|
||||
self,
|
||||
docs,
|
||||
layers,
|
||||
*,
|
||||
has_upper,
|
||||
unseen_classes=None,
|
||||
train=True,
|
||||
dropout=0.1
|
||||
):
|
||||
Model.__init__(self, name="parser_step_model", forward=step_forward)
|
||||
self.attrs["has_upper"] = has_upper
|
||||
self.attrs["dropout_rate"] = dropout
|
||||
|
@ -268,8 +304,10 @@ class ParserStepModel(Model):
|
|||
return ids
|
||||
|
||||
def backprop_step(self, token_ids, d_vector, get_d_tokvecs):
|
||||
if isinstance(self.state2vec.ops, CupyOps) \
|
||||
and not isinstance(token_ids, self.state2vec.ops.xp.ndarray):
|
||||
if (
|
||||
isinstance(self.state2vec.ops, CupyOps)
|
||||
and not isinstance(token_ids, self.state2vec.ops.xp.ndarray)
|
||||
):
|
||||
# Move token_ids and d_vector to GPU, asynchronously
|
||||
self.backprops.append((
|
||||
util.get_async(self.cuda_stream, token_ids),
|
||||
|
@ -279,7 +317,6 @@ class ParserStepModel(Model):
|
|||
else:
|
||||
self.backprops.append((token_ids, d_vector, get_d_tokvecs))
|
||||
|
||||
|
||||
def finish_steps(self, golds):
|
||||
# Add a padding vector to the d_tokvecs gradient, so that missing
|
||||
# values don't affect the real gradient.
|
||||
|
@ -292,14 +329,15 @@ class ParserStepModel(Model):
|
|||
ids = ids.flatten()
|
||||
d_state_features = d_state_features.reshape(
|
||||
(ids.size, d_state_features.shape[2]))
|
||||
self.ops.scatter_add(d_tokvecs, ids,
|
||||
d_state_features)
|
||||
self.ops.scatter_add(d_tokvecs, ids, d_state_features)
|
||||
# Padded -- see update()
|
||||
self.bp_tokvecs(d_tokvecs[:-1])
|
||||
return d_tokvecs
|
||||
|
||||
|
||||
NUMPY_OPS = NumpyOps()
|
||||
|
||||
|
||||
def step_forward(model: ParserStepModel, states, is_train):
|
||||
token_ids = model.get_token_ids(states)
|
||||
vector, get_d_tokvecs = model.state2vec(token_ids, is_train)
|
||||
|
@ -312,7 +350,7 @@ def step_forward(model: ParserStepModel, states, is_train):
|
|||
scores, get_d_vector = model.vec2scores(vector, is_train)
|
||||
else:
|
||||
scores = NumpyOps().asarray(vector)
|
||||
get_d_vector = lambda d_scores: d_scores
|
||||
get_d_vector = lambda d_scores: d_scores # no-cython-lint: E731
|
||||
# If the class is unseen, make sure its score is minimum
|
||||
scores[:, model._class_mask == 0] = numpy.nanmin(scores)
|
||||
|
||||
|
@ -448,9 +486,11 @@ cdef class precompute_hiddens:
|
|||
|
||||
feat_weights = self.get_feat_weights()
|
||||
cdef int[:, ::1] ids = token_ids
|
||||
sum_state_features(cblas, <float*>state_vector.data,
|
||||
feat_weights, &ids[0,0],
|
||||
token_ids.shape[0], self.nF, self.nO*self.nP)
|
||||
sum_state_features(
|
||||
cblas, <float*>state_vector.data,
|
||||
feat_weights, &ids[0, 0],
|
||||
token_ids.shape[0], self.nF, self.nO*self.nP
|
||||
)
|
||||
state_vector += self.bias
|
||||
state_vector, bp_nonlinearity = self._nonlinearity(state_vector)
|
||||
|
||||
|
@ -475,7 +515,7 @@ cdef class precompute_hiddens:
|
|||
|
||||
def backprop_maxout(d_best):
|
||||
return self.ops.backprop_maxout(d_best, mask, self.nP)
|
||||
|
||||
|
||||
return state_vector, backprop_maxout
|
||||
|
||||
def _relu_nonlinearity(self, state_vector):
|
||||
|
@ -489,5 +529,5 @@ cdef class precompute_hiddens:
|
|||
def backprop_relu(d_best):
|
||||
d_best *= mask
|
||||
return d_best.reshape((d_best.shape + (1,)))
|
||||
|
||||
|
||||
return state_vector, backprop_relu
|
||||
|
|
|
@ -9,7 +9,7 @@ from thinc.util import partial
|
|||
from ..attrs import ORTH
|
||||
from ..errors import Errors, Warnings
|
||||
from ..tokens import Doc
|
||||
from ..vectors import Mode
|
||||
from ..vectors import Mode, Vectors
|
||||
from ..vocab import Vocab
|
||||
|
||||
|
||||
|
@ -48,11 +48,14 @@ def forward(
|
|||
key_attr: int = getattr(vocab.vectors, "attr", ORTH)
|
||||
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||
if vocab.vectors.mode == Mode.default:
|
||||
if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
|
||||
V = model.ops.asarray(vocab.vectors.data)
|
||||
rows = vocab.vectors.find(keys=keys)
|
||||
V = model.ops.as_contig(V[rows])
|
||||
elif vocab.vectors.mode == Mode.floret:
|
||||
elif isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.floret:
|
||||
V = vocab.vectors.get_batch(keys)
|
||||
V = model.ops.as_contig(V)
|
||||
elif hasattr(vocab.vectors, "get_batch"):
|
||||
V = vocab.vectors.get_batch(keys)
|
||||
V = model.ops.as_contig(V)
|
||||
else:
|
||||
|
@ -61,7 +64,7 @@ def forward(
|
|||
vectors_data = model.ops.gemm(V, W, trans2=True)
|
||||
except ValueError:
|
||||
raise RuntimeError(Errors.E896)
|
||||
if vocab.vectors.mode == Mode.default:
|
||||
if isinstance(vocab.vectors, Vectors) and vocab.vectors.mode == Mode.default:
|
||||
# Convert negative indices to 0-vectors
|
||||
# TODO: more options for UNK tokens
|
||||
vectors_data[rows < 0] = 0
|
||||
|
|
|
@ -11,7 +11,7 @@ from .typedefs cimport attr_t, hash_t
|
|||
cdef class Morphology:
|
||||
cdef readonly Pool mem
|
||||
cdef readonly StringStore strings
|
||||
cdef PreshMap tags # Keyed by hash, value is pointer to tag
|
||||
cdef PreshMap tags # Keyed by hash, value is pointer to tag
|
||||
|
||||
cdef MorphAnalysisC create_morph_tag(self, field_feature_pairs) except *
|
||||
cdef int insert(self, MorphAnalysisC tag) except -1
|
||||
|
@ -20,4 +20,8 @@ cdef class Morphology:
|
|||
cdef int check_feature(const MorphAnalysisC* morph, attr_t feature) nogil
|
||||
cdef list list_features(const MorphAnalysisC* morph)
|
||||
cdef np.ndarray get_by_field(const MorphAnalysisC* morph, attr_t field)
|
||||
cdef int get_n_by_field(attr_t* results, const MorphAnalysisC* morph, attr_t field) nogil
|
||||
cdef int get_n_by_field(
|
||||
attr_t* results,
|
||||
const MorphAnalysisC* morph,
|
||||
attr_t field,
|
||||
) nogil
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# cython: infer_types
|
||||
# cython: profile=False
|
||||
import warnings
|
||||
|
||||
import numpy
|
||||
|
@ -83,10 +84,11 @@ cdef class Morphology:
|
|||
features = self.normalize_attrs(features)
|
||||
string_features = {self.strings.as_string(field): self.strings.as_string(values) for field, values in features.items()}
|
||||
# normalized UFEATS string with sorted fields and values
|
||||
norm_feats_string = self.FEATURE_SEP.join(sorted([
|
||||
self.FIELD_SEP.join([field, values])
|
||||
for field, values in string_features.items()
|
||||
]))
|
||||
norm_feats_string = self.FEATURE_SEP.join(
|
||||
sorted(
|
||||
[self.FIELD_SEP.join([field, values]) for field, values in string_features.items()]
|
||||
)
|
||||
)
|
||||
return norm_feats_string or self.EMPTY_MORPH
|
||||
|
||||
def normalize_attrs(self, attrs):
|
||||
|
@ -192,6 +194,7 @@ cdef int get_n_by_field(attr_t* results, const MorphAnalysisC* morph, attr_t fie
|
|||
n_results += 1
|
||||
return n_results
|
||||
|
||||
|
||||
def unpickle_morphology(strings, tags):
|
||||
cdef Morphology morphology = Morphology(strings)
|
||||
for tag in tags:
|
||||
|
|
|
@ -8,7 +8,7 @@ cpdef enum univ_pos_t:
|
|||
ADV
|
||||
AUX
|
||||
CONJ
|
||||
CCONJ # U20
|
||||
CCONJ # U20
|
||||
DET
|
||||
INTJ
|
||||
NOUN
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
|
||||
# cython: profile=False
|
||||
IDS = {
|
||||
"": NO_TAG,
|
||||
"ADJ": ADJ,
|
||||
|
|
|
@ -46,11 +46,18 @@ cdef struct EditTreeC:
|
|||
bint is_match_node
|
||||
NodeC inner
|
||||
|
||||
cdef inline EditTreeC edittree_new_match(len_t prefix_len, len_t suffix_len,
|
||||
uint32_t prefix_tree, uint32_t suffix_tree):
|
||||
cdef MatchNodeC match_node = MatchNodeC(prefix_len=prefix_len,
|
||||
suffix_len=suffix_len, prefix_tree=prefix_tree,
|
||||
suffix_tree=suffix_tree)
|
||||
cdef inline EditTreeC edittree_new_match(
|
||||
len_t prefix_len,
|
||||
len_t suffix_len,
|
||||
uint32_t prefix_tree,
|
||||
uint32_t suffix_tree
|
||||
):
|
||||
cdef MatchNodeC match_node = MatchNodeC(
|
||||
prefix_len=prefix_len,
|
||||
suffix_len=suffix_len,
|
||||
prefix_tree=prefix_tree,
|
||||
suffix_tree=suffix_tree
|
||||
)
|
||||
cdef NodeC inner = NodeC(match_node=match_node)
|
||||
return EditTreeC(is_match_node=True, inner=inner)
|
||||
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
# cython: infer_types=True, binding=True
|
||||
# cython: profile=False
|
||||
from cython.operator cimport dereference as deref
|
||||
from libc.stdint cimport UINT32_MAX, uint32_t
|
||||
from libc.string cimport memset
|
||||
from libcpp.pair cimport pair
|
||||
from libcpp.vector cimport vector
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ...typedefs cimport hash_t
|
||||
|
||||
from ... import util
|
||||
|
@@ -25,17 +24,16 @@ cdef LCS find_lcs(str source, str target):
    target (str): The second string.
    RETURNS (LCS): The spans of the longest common subsequences.
    """
    cdef Py_ssize_t source_len = len(source)
    cdef Py_ssize_t target_len = len(target)
    cdef size_t longest_align = 0;
    cdef size_t longest_align = 0
    cdef int source_idx, target_idx
    cdef LCS lcs
    cdef Py_UCS4 source_cp, target_cp

    memset(&lcs, 0, sizeof(lcs))

    cdef vector[size_t] prev_aligns = vector[size_t](target_len);
    cdef vector[size_t] cur_aligns = vector[size_t](target_len);
    cdef vector[size_t] prev_aligns = vector[size_t](target_len)
    cdef vector[size_t] cur_aligns = vector[size_t](target_len)

    for (source_idx, source_cp) in enumerate(source):
        for (target_idx, target_cp) in enumerate(target):
@ -89,7 +87,7 @@ cdef class EditTrees:
|
|||
cdef LCS lcs = find_lcs(form, lemma)
|
||||
|
||||
cdef EditTreeC tree
|
||||
cdef uint32_t tree_id, prefix_tree, suffix_tree
|
||||
cdef uint32_t prefix_tree, suffix_tree
|
||||
if lcs_is_empty(lcs):
|
||||
tree = edittree_new_subst(self.strings.add(form), self.strings.add(lemma))
|
||||
else:
|
||||
|
@ -108,7 +106,7 @@ cdef class EditTrees:
|
|||
return self._tree_id(tree)
|
||||
|
||||
cdef uint32_t _tree_id(self, EditTreeC tree):
|
||||
# If this tree has been constructed before, return its identifier.
|
||||
# If this tree has been constructed before, return its identifier.
|
||||
cdef hash_t hash = edittree_hash(tree)
|
||||
cdef unordered_map[hash_t, uint32_t].iterator iter = self.map.find(hash)
|
||||
if iter != self.map.end():
|
||||
|
@ -289,6 +287,7 @@ def _tree2dict(tree):
|
|||
tree = tree["inner"]["subst_node"]
|
||||
return(dict(tree))
|
||||
|
||||
|
||||
def _dict2tree(tree):
|
||||
errors = validate_edit_tree(tree)
|
||||
if errors:
|
||||
|
|
|
@@ -1,8 +1,12 @@
from collections import defaultdict
from typing import Any, Dict, List, Union

from pydantic import BaseModel, Field, ValidationError
from pydantic.types import StrictBool, StrictInt, StrictStr
try:
    from pydantic.v1 import BaseModel, Field, ValidationError
    from pydantic.v1.types import StrictBool, StrictInt, StrictStr
except ImportError:
    from pydantic import BaseModel, Field, ValidationError  # type: ignore
    from pydantic.types import StrictBool, StrictInt, StrictStr  # type: ignore


class MatchNodeSchema(BaseModel):
@ -1,17 +1,13 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
cimport numpy as np
|
||||
|
||||
import numpy
|
||||
|
||||
from cpython.ref cimport Py_XDECREF, PyObject
|
||||
from thinc.extra.search cimport Beam
|
||||
|
||||
from thinc.extra.search import MaxViolation
|
||||
|
||||
from thinc.extra.search cimport MaxViolation
|
||||
|
||||
from ...typedefs cimport class_t, hash_t
|
||||
from ...typedefs cimport class_t
|
||||
from .transition_system cimport Transition, TransitionSystem
|
||||
|
||||
from ...errors import Errors
|
||||
|
@ -146,7 +142,6 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
|||
cdef MaxViolation violn
|
||||
pbeam = BeamBatch(moves, states, golds, width=width, density=beam_density)
|
||||
gbeam = BeamBatch(moves, states, golds, width=width, density=0.0)
|
||||
cdef StateClass state
|
||||
beam_maps = []
|
||||
backprops = []
|
||||
violns = [MaxViolation() for _ in range(len(states))]
|
||||
|
|
|
@ -277,7 +277,6 @@ cdef cppclass StateC:
|
|||
|
||||
return n
|
||||
|
||||
|
||||
int n_L(int head) nogil const:
|
||||
return n_arcs(this._left_arcs, head)
|
||||
|
||||
|
|
|
@@ -0,0 +1 @@
# cython: profile=False
@@ -1,4 +1,4 @@
# cython: profile=True, cdivision=True, infer_types=True
# cython: cdivision=True, infer_types=True
from cymem.cymem cimport Address, Pool
from libc.stdint cimport int32_t
from libcpp.vector cimport vector
@ -9,7 +9,7 @@ from ...strings cimport hash_string
|
|||
from ...structs cimport TokenC
|
||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...tokens.token cimport MISSING_DEP
|
||||
from ...typedefs cimport attr_t, hash_t
|
||||
from ...typedefs cimport attr_t
|
||||
|
||||
from ...training import split_bilu_label
|
||||
|
||||
|
@ -68,8 +68,9 @@ cdef struct GoldParseStateC:
|
|||
weight_t pop_cost
|
||||
|
||||
|
||||
cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
|
||||
heads, labels, sent_starts) except *:
|
||||
cdef GoldParseStateC create_gold_state(
|
||||
Pool mem, const StateC* state, heads, labels, sent_starts
|
||||
) except *:
|
||||
cdef GoldParseStateC gs
|
||||
gs.length = len(heads)
|
||||
gs.stride = 1
|
||||
|
@ -82,7 +83,7 @@ cdef GoldParseStateC create_gold_state(Pool mem, const StateC* state,
|
|||
gs.n_kids_in_stack = <int32_t*>mem.alloc(gs.length, sizeof(gs.n_kids_in_stack[0]))
|
||||
|
||||
for i, is_sent_start in enumerate(sent_starts):
|
||||
if is_sent_start == True:
|
||||
if is_sent_start is True:
|
||||
gs.state_bits[i] = set_state_flag(
|
||||
gs.state_bits[i],
|
||||
IS_SENT_START,
|
||||
|
@ -210,6 +211,7 @@ cdef class ArcEagerGold:
|
|||
def update(self, StateClass stcls):
|
||||
update_gold_state(&self.c, stcls.c)
|
||||
|
||||
|
||||
def _get_aligned_sent_starts(example):
|
||||
"""Get list of SENT_START attributes aligned to the predicted tokenization.
|
||||
If the reference has not sentence starts, return a list of None values.
|
||||
|
@ -524,7 +526,6 @@ cdef class Break:
|
|||
"""
|
||||
@staticmethod
|
||||
cdef bint is_valid(const StateC* st, attr_t label) nogil:
|
||||
cdef int i
|
||||
if st.buffer_length() < 2:
|
||||
return False
|
||||
elif st.B(1) != st.B(0) + 1:
|
||||
|
@ -556,8 +557,8 @@ cdef class Break:
|
|||
cost -= 1
|
||||
if gold.heads[si] == b0:
|
||||
cost -= 1
|
||||
if not is_sent_start(gold, state.B(1)) \
|
||||
and not is_sent_start_unknown(gold, state.B(1)):
|
||||
if not is_sent_start(gold, state.B(1)) and\
|
||||
not is_sent_start_unknown(gold, state.B(1)):
|
||||
cost += 1
|
||||
return cost
|
||||
|
||||
|
@ -803,7 +804,6 @@ cdef class ArcEager(TransitionSystem):
|
|||
raise TypeError(Errors.E909.format(name="ArcEagerGold"))
|
||||
cdef ArcEagerGold gold_ = gold
|
||||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
||||
else:
|
||||
|
@ -875,7 +875,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
print("Gold")
|
||||
for token in example.y:
|
||||
print(token.i, token.text, token.dep_, token.head.text)
|
||||
aligned_heads, aligned_labels = example.get_aligned_parse()
|
||||
aligned_heads, _aligned_labels = example.get_aligned_parse()
|
||||
print("Aligned heads")
|
||||
for i, head in enumerate(aligned_heads):
|
||||
print(example.x[i], example.x[head] if head is not None else "__")
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
# cython: profile=False
|
||||
from cymem.cymem cimport Pool
|
||||
from libc.stdint cimport int32_t
|
||||
|
||||
|
@ -14,7 +12,7 @@ from ...tokens.span import Span
|
|||
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...lexeme cimport Lexeme
|
||||
from ...structs cimport SpanC, TokenC
|
||||
from ...structs cimport SpanC
|
||||
from ...tokens.span cimport Span
|
||||
from ...typedefs cimport attr_t, weight_t
|
||||
|
||||
|
@ -141,11 +139,10 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
OUT: Counter()
|
||||
}
|
||||
actions[OUT][''] = 1 # Represents a token predicted to be outside of any entity
|
||||
actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity
|
||||
actions[UNIT][''] = 1 # Represents a token prohibited to be in an entity
|
||||
for entity_type in kwargs.get('entity_types', []):
|
||||
for action in (BEGIN, IN, LAST, UNIT):
|
||||
actions[action][entity_type] = 1
|
||||
moves = ('M', 'B', 'I', 'L', 'U')
|
||||
for example in kwargs.get('examples', []):
|
||||
for token in example.y:
|
||||
ent_type = token.ent_type_
|
||||
|
@ -164,7 +161,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
if token.ent_type:
|
||||
labels.add(token.ent_type_)
|
||||
return labels
|
||||
|
||||
|
||||
def move_name(self, int move, attr_t label):
|
||||
if move == OUT:
|
||||
return 'O'
|
||||
|
@ -325,7 +322,6 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
raise TypeError(Errors.E909.format(name="BiluoGold"))
|
||||
cdef BiluoGold gold_ = gold
|
||||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls.c, &gold_state, self.c[i].label)
|
||||
else:
|
||||
|
@ -486,10 +482,8 @@ cdef class In:
|
|||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
move = IN
|
||||
cdef int next_act = gold.ner[s.B(1)].move if s.B(1) >= 0 else OUT
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||
|
||||
if g_act == MISSING:
|
||||
|
@ -549,12 +543,10 @@ cdef class Last:
|
|||
@staticmethod
|
||||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
move = LAST
|
||||
b0 = s.B(0)
|
||||
ent_start = s.E(0)
|
||||
|
||||
cdef int g_act = gold.ner[b0].move
|
||||
cdef attr_t g_tag = gold.ner[b0].label
|
||||
|
||||
cdef int cost = 0
|
||||
|
||||
|
@ -650,7 +642,6 @@ cdef class Unit:
|
|||
cost += 1
|
||||
break
|
||||
return cost
|
||||
|
||||
|
||||
|
||||
cdef class Out:
|
||||
|
@ -675,7 +666,6 @@ cdef class Out:
|
|||
cdef weight_t cost(const StateC* s, const void* _gold, attr_t label) nogil:
|
||||
gold = <GoldNERStateC*>_gold
|
||||
cdef int g_act = gold.ner[s.B(0)].move
|
||||
cdef attr_t g_tag = gold.ner[s.B(0)].label
|
||||
cdef weight_t cost = 0
|
||||
if g_act == MISSING:
|
||||
pass
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: profile=True, infer_types=True
|
||||
# cython: infer_types=True
|
||||
"""Implements the projectivize/deprojectivize mechanism in Nivre & Nilsson 2005
|
||||
for doing pseudo-projective parsing implementation uses the HEAD decoration
|
||||
scheme.
|
||||
|
@ -125,14 +125,17 @@ def decompose(label):
|
|||
def is_decorated(label):
|
||||
return DELIMITER in label
|
||||
|
||||
|
||||
def count_decorated_labels(gold_data):
|
||||
freqs = {}
|
||||
for example in gold_data:
|
||||
proj_heads, deco_deps = projectivize(example.get_aligned("HEAD"),
|
||||
example.get_aligned("DEP"))
|
||||
# set the label to ROOT for each root dependent
|
||||
deco_deps = ['ROOT' if head == i else deco_deps[i]
|
||||
for i, head in enumerate(proj_heads)]
|
||||
deco_deps = [
|
||||
'ROOT' if head == i else deco_deps[i]
|
||||
for i, head in enumerate(proj_heads)
|
||||
]
|
||||
# count label frequencies
|
||||
for label in deco_deps:
|
||||
if is_decorated(label):
|
||||
|
@@ -160,9 +163,9 @@ def projectivize(heads, labels):


cdef vector[int] _heads_to_c(heads):
    cdef vector[int] c_heads;
    cdef vector[int] c_heads
    for head in heads:
        if head == None:
        if head is None:
            c_heads.push_back(-1)
        else:
            assert head < len(heads)
@ -199,6 +202,7 @@ def _decorate(heads, proj_heads, labels):
|
|||
deco_labels.append(labels[tokenid])
|
||||
return deco_labels
|
||||
|
||||
|
||||
def get_smallest_nonproj_arc_slow(heads):
|
||||
cdef vector[int] c_heads = _heads_to_c(heads)
|
||||
return _get_smallest_nonproj_arc(c_heads)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# cython: infer_types=True
|
||||
import numpy
|
||||
|
||||
# cython: profile=False
|
||||
from libcpp.vector cimport vector
|
||||
|
||||
from ...tokens.doc cimport Doc
|
||||
|
@ -38,11 +37,11 @@ cdef class StateClass:
|
|||
cdef vector[ArcC] arcs
|
||||
self.c.get_arcs(&arcs)
|
||||
return list(arcs)
|
||||
#py_arcs = []
|
||||
#for arc in arcs:
|
||||
# if arc.head != -1 and arc.child != -1:
|
||||
# py_arcs.append((arc.head, arc.child, arc.label))
|
||||
#return arcs
|
||||
# py_arcs = []
|
||||
# for arc in arcs:
|
||||
# if arc.head != -1 and arc.child != -1:
|
||||
# py_arcs.append((arc.head, arc.child, arc.label))
|
||||
# return arcs
|
||||
|
||||
def add_arc(self, int head, int child, int label):
|
||||
self.c.add_arc(head, child, label)
|
||||
|
@ -52,10 +51,10 @@ cdef class StateClass:
|
|||
|
||||
def H(self, int child):
|
||||
return self.c.H(child)
|
||||
|
||||
|
||||
def L(self, int head, int idx):
|
||||
return self.c.L(head, idx)
|
||||
|
||||
|
||||
def R(self, int head, int idx):
|
||||
return self.c.R(head, idx)
|
||||
|
||||
|
@ -98,7 +97,7 @@ cdef class StateClass:
|
|||
|
||||
def H(self, int i):
|
||||
return self.c.H(i)
|
||||
|
||||
|
||||
def E(self, int i):
|
||||
return self.c.E(i)
|
||||
|
||||
|
@ -116,7 +115,7 @@ cdef class StateClass:
|
|||
|
||||
def H_(self, int i):
|
||||
return self.doc[self.c.H(i)]
|
||||
|
||||
|
||||
def E_(self, int i):
|
||||
return self.doc[self.c.E(i)]
|
||||
|
||||
|
@ -125,7 +124,7 @@ cdef class StateClass:
|
|||
|
||||
def R_(self, int i, int idx):
|
||||
return self.doc[self.c.R(i, idx)]
|
||||
|
||||
|
||||
def empty(self):
|
||||
return self.c.empty()
|
||||
|
||||
|
@ -134,7 +133,7 @@ cdef class StateClass:
|
|||
|
||||
def at_break(self):
|
||||
return False
|
||||
#return self.c.at_break()
|
||||
# return self.c.at_break()
|
||||
|
||||
def has_head(self, int i):
|
||||
return self.c.has_head(i)
|
||||
|
|
|
@ -20,11 +20,15 @@ cdef struct Transition:
|
|||
int (*do)(StateC* state, attr_t label) nogil
|
||||
|
||||
|
||||
ctypedef weight_t (*get_cost_func_t)(const StateC* state, const void* gold,
|
||||
attr_tlabel) nogil
|
||||
ctypedef weight_t (*move_cost_func_t)(const StateC* state, const void* gold) nogil
|
||||
ctypedef weight_t (*label_cost_func_t)(const StateC* state, const void*
|
||||
gold, attr_t label) nogil
|
||||
ctypedef weight_t (*get_cost_func_t)(
|
||||
const StateC* state, const void* gold, attr_tlabel
|
||||
) nogil
|
||||
ctypedef weight_t (*move_cost_func_t)(
|
||||
const StateC* state, const void* gold
|
||||
) nogil
|
||||
ctypedef weight_t (*label_cost_func_t)(
|
||||
const StateC* state, const void* gold, attr_t label
|
||||
) nogil
|
||||
|
||||
ctypedef int (*do_func_t)(StateC* state, attr_t label) nogil
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=False
|
||||
from __future__ import print_function
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
|
@ -8,9 +9,7 @@ from collections import Counter
|
|||
import srsly
|
||||
|
||||
from ...structs cimport TokenC
|
||||
from ...tokens.doc cimport Doc
|
||||
from ...typedefs cimport attr_t, weight_t
|
||||
from . cimport _beam_utils
|
||||
from .stateclass cimport StateClass
|
||||
|
||||
from ... import util
|
||||
|
@ -231,7 +230,6 @@ cdef class TransitionSystem:
|
|||
return self
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
transitions = []
|
||||
serializers = {
|
||||
'moves': lambda: srsly.json_dumps(self.labels),
|
||||
'strings': lambda: self.strings.to_bytes(),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Iterable, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from thinc.api import Config, Model
|
||||
|
||||
|
@ -124,6 +124,7 @@ def make_parser(
|
|||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"beam_parser",
|
||||
assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from itertools import islice
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import srsly
|
||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
||||
|
||||
from ..morphology cimport Morphology
|
||||
|
@ -14,10 +13,8 @@ from ..errors import Errors
|
|||
from ..language import Language
|
||||
from ..parts_of_speech import IDS as POS_IDS
|
||||
from ..scorer import Scorer
|
||||
from ..symbols import POS
|
||||
from ..training import validate_examples, validate_get_examples
|
||||
from ..util import registry
|
||||
from .pipe import deserialize_config
|
||||
from .tagger import Tagger
|
||||
|
||||
# See #9050
|
||||
|
@@ -76,8 +73,11 @@ def morphologizer_score(examples, **kwargs):
    results = {}
    results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
    results.update(Scorer.score_token_attr(examples, "morph", getter=morph_key_getter, **kwargs))
    results.update(Scorer.score_token_attr_per_feat(examples,
                                                    "morph", getter=morph_key_getter, **kwargs))
    results.update(
        Scorer.score_token_attr_per_feat(
            examples, "morph", getter=morph_key_getter, **kwargs
        )
    )
    return results
@ -233,7 +233,6 @@ class Morphologizer(Tagger):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef Vocab vocab = self.vocab
|
||||
cdef bint overwrite = self.cfg["overwrite"]
|
||||
cdef bint extend = self.cfg["extend"]
|
||||
labels = self.labels
|
||||
|
|
|
@ -1,16 +1,13 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from typing import Optional
|
||||
|
||||
import numpy
|
||||
from thinc.api import Config, CosineDistance, Model, set_dropout_rate, to_categorical
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
||||
from ..attrs import ID, POS
|
||||
from ..attrs import ID
|
||||
from ..errors import Errors
|
||||
from ..language import Language
|
||||
from ..training import validate_examples
|
||||
from ._parser_internals import nonproj
|
||||
from .tagger import Tagger
|
||||
from .trainable_pipe import TrainablePipe
|
||||
|
||||
|
@ -103,10 +100,9 @@ class MultitaskObjective(Tagger):
|
|||
cdef int idx = 0
|
||||
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
||||
guesses = scores.argmax(axis=1)
|
||||
docs = [eg.predicted for eg in examples]
|
||||
for i, eg in enumerate(examples):
|
||||
# Handles alignment for tokenization differences
|
||||
doc_annots = eg.get_aligned() # TODO
|
||||
_doc_annots = eg.get_aligned() # TODO
|
||||
for j in range(len(eg.predicted)):
|
||||
tok_annots = {key: values[j] for key, values in tok_annots.items()}
|
||||
label = self.make_label(j, tok_annots)
|
||||
|
@ -206,7 +202,6 @@ class ClozeMultitask(TrainablePipe):
|
|||
losses[self.name] = 0.
|
||||
set_dropout_rate(self.model, drop)
|
||||
validate_examples(examples, "ClozeMultitask.rehearse")
|
||||
docs = [eg.predicted for eg in examples]
|
||||
predictions, bp_predictions = self.model.begin_update()
|
||||
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
|
||||
bp_predictions(d_predictions)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Iterable, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from thinc.api import Config, Model
|
||||
|
||||
|
@ -10,7 +10,7 @@ from ._parser_internals.ner cimport BiluoPushDown
|
|||
from .transition_parser cimport Parser
|
||||
|
||||
from ..language import Language
|
||||
from ..scorer import PRFScore, get_ner_prf
|
||||
from ..scorer import get_ner_prf
|
||||
from ..training import remove_bilu_prefix
|
||||
from ..util import registry
|
||||
|
||||
|
@ -100,6 +100,7 @@ def make_ner(
|
|||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"beam_ner",
|
||||
assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
import warnings
|
||||
from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, Iterable, Iterator, Tuple, Union
|
||||
|
||||
import srsly
|
||||
|
||||
|
@ -40,7 +40,7 @@ cdef class Pipe:
|
|||
"""
|
||||
raise NotImplementedError(Errors.E931.format(parent="Pipe", method="__call__", name=self.name))
|
||||
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
||||
"""Apply the pipe to a stream of documents. This usually happens under
|
||||
the hood when the nlp object is called on a text and all components are
|
||||
applied to the Doc.
|
||||
|
@ -59,7 +59,7 @@ cdef class Pipe:
|
|||
except Exception as e:
|
||||
error_handler(self.name, self, [doc], e)
|
||||
|
||||
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
|
||||
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language = None):
|
||||
"""Initialize the pipe. For non-trainable components, this method
|
||||
is optional. For trainable components, which should inherit
|
||||
from the subclass TrainablePipe, the provided data examples
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import srsly
|
||||
|
@ -7,13 +7,13 @@ from ..tokens.doc cimport Doc
|
|||
|
||||
from .. import util
|
||||
from ..language import Language
|
||||
from ..scorer import Scorer
|
||||
from .pipe import Pipe
|
||||
from .senter import senter_score
|
||||
|
||||
# see #9050
|
||||
BACKWARD_OVERWRITE = False
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"sentencizer",
|
||||
assigns=["token.is_sent_start", "doc.sents"],
|
||||
|
@ -36,17 +36,19 @@ class Sentencizer(Pipe):
|
|||
DOCS: https://spacy.io/api/sentencizer
|
||||
"""
|
||||
|
||||
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
|
||||
'।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
|
||||
'᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
|
||||
'‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
|
||||
'꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
|
||||
'﹖', '﹗', '!', '.', '?', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
|
||||
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
|
||||
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
|
||||
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
|
||||
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
|
||||
'。', '。']
|
||||
default_punct_chars = [
|
||||
'!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
|
||||
'।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
|
||||
'᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
|
||||
'‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
|
||||
'꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
|
||||
'﹖', '﹗', '!', '.', '?', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
|
||||
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
|
||||
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
|
||||
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
|
||||
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
|
||||
'。', '。'
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -128,7 +130,6 @@ class Sentencizer(Pipe):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef int idx = 0
|
||||
for i, doc in enumerate(docs):
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
for j, tag_id in enumerate(doc_tag_ids):
|
||||
|
@ -169,7 +170,6 @@ class Sentencizer(Pipe):
|
|||
path = path.with_suffix(".json")
|
||||
srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})
|
||||
|
||||
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Load the sentencizer from disk.
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from itertools import islice
|
||||
from typing import Callable, Optional
|
||||
|
||||
import srsly
|
||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
|
|
@@ -1,26 +1,18 @@
# cython: infer_types=True, profile=True, binding=True
import warnings
# cython: infer_types=True, binding=True
from itertools import islice
from typing import Callable, Optional

import numpy
import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, set_dropout_rate
from thinc.types import Floats2d

from ..morphology cimport Morphology
from ..tokens.doc cimport Doc
from ..vocab cimport Vocab

from .. import util
from ..attrs import ID, POS
from ..errors import Errors, Warnings
from ..errors import Errors
from ..language import Language
from ..parts_of_speech import X
from ..scorer import Scorer
from ..training import validate_examples, validate_get_examples
from ..util import registry
from .pipe import deserialize_config
from .trainable_pipe import TrainablePipe

# See #9050
@ -169,7 +161,6 @@ class Tagger(TrainablePipe):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef Vocab vocab = self.vocab
|
||||
cdef bint overwrite = self.cfg["overwrite"]
|
||||
labels = self.labels
|
||||
for i, doc in enumerate(docs):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
# cython: infer_types=True, binding=True
|
||||
from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple
|
||||
|
||||
import srsly
|
||||
|
@ -55,7 +55,7 @@ cdef class TrainablePipe(Pipe):
|
|||
except Exception as e:
|
||||
error_handler(self.name, self, [doc], e)
|
||||
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
||||
"""Apply the pipe to a stream of documents. This usually happens under
|
||||
the hood when the nlp object is called on a text and all components are
|
||||
applied to the Doc.
|
||||
|
@ -102,9 +102,9 @@ cdef class TrainablePipe(Pipe):
|
|||
def update(self,
|
||||
examples: Iterable["Example"],
|
||||
*,
|
||||
drop: float=0.0,
|
||||
sgd: Optimizer=None,
|
||||
losses: Optional[Dict[str, float]]=None) -> Dict[str, float]:
|
||||
drop: float = 0.0,
|
||||
sgd: Optimizer = None,
|
||||
losses: Optional[Dict[str, float]] = None) -> Dict[str, float]:
|
||||
"""Learn from a batch of documents and gold-standard information,
|
||||
updating the pipe's model. Delegates to predict and get_loss.
|
||||
|
||||
|
@ -138,8 +138,8 @@ cdef class TrainablePipe(Pipe):
|
|||
def rehearse(self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
sgd: Optimizer=None,
|
||||
losses: Dict[str, float]=None,
|
||||
sgd: Optimizer = None,
|
||||
losses: Dict[str, float] = None,
|
||||
**config) -> Dict[str, float]:
|
||||
"""Perform a "rehearsal" update from a batch of data. Rehearsal updates
|
||||
teach the current model to make predictions similar to an initial model,
|
||||
|
@ -177,7 +177,7 @@ cdef class TrainablePipe(Pipe):
|
|||
"""
|
||||
return util.create_default_optimizer()
|
||||
|
||||
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
|
||||
def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language = None):
|
||||
"""Initialize the pipe for training, using data examples if available.
|
||||
This method needs to be implemented by each TrainablePipe component,
|
||||
ensuring the internal model (if available) is initialized properly
|
||||
|
|
|
@@ -13,8 +13,18 @@ cdef class Parser(TrainablePipe):
    cdef readonly TransitionSystem moves
    cdef public object _multitasks

    cdef void _parseC(self, CBlas cblas, StateC** states,
                      WeightsC weights, SizesC sizes) nogil
    cdef void _parseC(
        self,
        CBlas cblas,
        StateC** states,
        WeightsC weights,
        SizesC sizes
    ) nogil

    cdef void c_transition_batch(self, StateC** states, const float* scores,
                                 int nr_class, int batch_size) nogil
    cdef void c_transition_batch(
        self,
        StateC** states,
        const float* scores,
        int nr_class,
        int batch_size
    ) nogil
@ -1,4 +1,5 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False, binding=True
|
||||
# cython: profile=False
|
||||
from __future__ import print_function
|
||||
|
||||
cimport numpy as np
|
||||
|
@ -7,20 +8,15 @@ from cymem.cymem cimport Pool
|
|||
from itertools import islice
|
||||
|
||||
from libc.stdlib cimport calloc, free
|
||||
from libc.string cimport memcpy, memset
|
||||
from libc.string cimport memset
|
||||
from libcpp.vector cimport vector
|
||||
|
||||
import random
|
||||
|
||||
import srsly
|
||||
from thinc.api import CupyOps, NumpyOps, get_ops, set_dropout_rate
|
||||
|
||||
from thinc.extra.search cimport Beam
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy
|
||||
import numpy.random
|
||||
import srsly
|
||||
from thinc.api import CupyOps, NumpyOps, set_dropout_rate
|
||||
|
||||
from ..ml.parser_model cimport (
|
||||
ActivationsC,
|
||||
|
@ -42,7 +38,7 @@ from .trainable_pipe import TrainablePipe
|
|||
from ._parser_internals cimport _beam_utils
|
||||
|
||||
from .. import util
|
||||
from ..errors import Errors, Warnings
|
||||
from ..errors import Errors
|
||||
from ..training import validate_examples, validate_get_examples
|
||||
from ._parser_internals import _beam_utils
|
||||
|
||||
|
@ -258,7 +254,6 @@ cdef class Parser(TrainablePipe):
|
|||
except Exception as e:
|
||||
error_handler(self.name, self, batch_in_order, e)
|
||||
|
||||
|
||||
def predict(self, docs):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
@ -300,8 +295,6 @@ cdef class Parser(TrainablePipe):
|
|||
return batch
|
||||
|
||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||
cdef Beam beam
|
||||
cdef Doc doc
|
||||
self._ensure_labels_are_added(docs)
|
||||
batch = _beam_utils.BeamBatch(
|
||||
self.moves,
|
||||
|
@ -321,16 +314,18 @@ cdef class Parser(TrainablePipe):
|
|||
del model
|
||||
return list(batch)
|
||||
|
||||
cdef void _parseC(self, CBlas cblas, StateC** states,
|
||||
WeightsC weights, SizesC sizes) nogil:
|
||||
cdef int i, j
|
||||
cdef void _parseC(
|
||||
self, CBlas cblas, StateC** states, WeightsC weights, SizesC sizes
|
||||
) nogil:
|
||||
cdef int i
|
||||
cdef vector[StateC*] unfinished
|
||||
cdef ActivationsC activations = alloc_activations(sizes)
|
||||
while sizes.states >= 1:
|
||||
predict_states(cblas, &activations, states, &weights, sizes)
|
||||
# Validate actions, argmax, take action.
|
||||
self.c_transition_batch(states,
|
||||
activations.scores, sizes.classes, sizes.states)
|
||||
self.c_transition_batch(
|
||||
states, activations.scores, sizes.classes, sizes.states
|
||||
)
|
||||
for i in range(sizes.states):
|
||||
if not states[i].is_final():
|
||||
unfinished.push_back(states[i])
|
||||
|
@ -342,7 +337,6 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
def set_annotations(self, docs, states_or_beams):
|
||||
cdef StateClass state
|
||||
cdef Beam beam
|
||||
cdef Doc doc
|
||||
states = _beam_utils.collect_states(states_or_beams, docs)
|
||||
for i, (state, doc) in enumerate(zip(states, docs)):
|
||||
|
@ -359,8 +353,13 @@ cdef class Parser(TrainablePipe):
|
|||
self.c_transition_batch(&c_states[0], c_scores, scores.shape[1], scores.shape[0])
|
||||
return [state for state in states if not state.c.is_final()]
|
||||
|
||||
cdef void c_transition_batch(self, StateC** states, const float* scores,
|
||||
int nr_class, int batch_size) nogil:
|
||||
cdef void c_transition_batch(
|
||||
self,
|
||||
StateC** states,
|
||||
const float* scores,
|
||||
int nr_class,
|
||||
int batch_size
|
||||
) nogil:
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
with gil:
|
||||
assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
|
||||
|
@ -380,7 +379,6 @@ cdef class Parser(TrainablePipe):
|
|||
free(is_valid)
|
||||
|
||||
def update(self, examples, *, drop=0., sgd=None, losses=None):
|
||||
cdef StateClass state
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.)
|
||||
|
@ -419,8 +417,7 @@ cdef class Parser(TrainablePipe):
|
|||
if not states:
|
||||
return losses
|
||||
model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
|
||||
|
||||
all_states = list(states)
|
||||
|
||||
states_golds = list(zip(states, golds))
|
||||
n_moves = 0
|
||||
while states_golds:
|
||||
|
@ -500,8 +497,16 @@ cdef class Parser(TrainablePipe):
|
|||
del tutor
|
||||
return losses
|
||||
|
||||
def update_beam(self, examples, *, beam_width,
|
||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
||||
def update_beam(
|
||||
self,
|
||||
examples,
|
||||
*,
|
||||
beam_width,
|
||||
drop=0.,
|
||||
sgd=None,
|
||||
losses=None,
|
||||
beam_density=0.0
|
||||
):
|
||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
if not states:
|
||||
return losses
|
||||
|
@ -531,8 +536,9 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
||||
costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
|
||||
cdef np.ndarray d_scores = numpy.zeros((len(states), self.moves.n_moves),
|
||||
dtype='f', order='C')
|
||||
cdef np.ndarray d_scores = numpy.zeros(
|
||||
(len(states), self.moves.n_moves), dtype='f', order='C'
|
||||
)
|
||||
c_d_scores = <float*>d_scores.data
|
||||
unseen_classes = self.model.attrs["unseen_classes"]
|
||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||
|
@ -542,8 +548,9 @@ cdef class Parser(TrainablePipe):
|
|||
for j in range(self.moves.n_moves):
|
||||
if costs[j] <= 0.0 and j in unseen_classes:
|
||||
unseen_classes.remove(j)
|
||||
cpu_log_loss(c_d_scores,
|
||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||
cpu_log_loss(
|
||||
c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1]
|
||||
)
|
||||
c_d_scores += d_scores.shape[1]
|
||||
# Note that we don't normalize this. See comment in update() for why.
|
||||
if losses is not None:
|
||||
|
|
spacy/schemas.py (102 changed lines)
@@ -16,19 +16,34 @@ from typing import (
    Union,
)

from pydantic import (
    BaseModel,
    ConstrainedStr,
    Field,
    StrictBool,
    StrictFloat,
    StrictInt,
    StrictStr,
    ValidationError,
    create_model,
    validator,
)
from pydantic.main import ModelMetaclass
try:
    from pydantic.v1 import (
        BaseModel,
        ConstrainedStr,
        Field,
        StrictBool,
        StrictFloat,
        StrictInt,
        StrictStr,
        ValidationError,
        create_model,
        validator,
    )
    from pydantic.v1.main import ModelMetaclass
except ImportError:
    from pydantic import (  # type: ignore
        BaseModel,
        ConstrainedStr,
        Field,
        StrictBool,
        StrictFloat,
        StrictInt,
        StrictStr,
        ValidationError,
        create_model,
        validator,
    )
    from pydantic.main import ModelMetaclass  # type: ignore
from thinc.api import ConfigValidationError, Model, Optimizer
from thinc.config import Promise
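Note: the hunk above prefers the pydantic.v1 compatibility layer so the schemas work under both pydantic 1.x and 2.x. A minimal standalone sketch of that import pattern (the ExampleSchema model is made up for illustration, not part of the diff):

try:
    from pydantic.v1 import BaseModel, StrictStr  # pydantic >= 2.x
except ImportError:
    from pydantic import BaseModel, StrictStr  # type: ignore  # pydantic 1.x


class ExampleSchema(BaseModel):
    name: StrictStr


print(ExampleSchema(name="spaCy"))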
@ -397,6 +412,7 @@ class ConfigSchemaNlp(BaseModel):
|
|||
after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
|
||||
after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
|
||||
batch_size: Optional[int] = Field(..., title="Default batch size")
|
||||
vectors: Callable = Field(..., title="Vectors implementation")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
|
@ -465,66 +481,6 @@ CONFIG_SCHEMAS = {
|
|||
"initialize": ConfigSchemaInit,
|
||||
}
|
||||
|
||||
|
||||
# Project config Schema
|
||||
|
||||
|
||||
class ProjectConfigAssetGitItem(BaseModel):
|
||||
# fmt: off
|
||||
repo: StrictStr = Field(..., title="URL of Git repo to download from")
|
||||
path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)")
|
||||
branch: StrictStr = Field("master", title="Branch to clone from")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ProjectConfigAssetURL(BaseModel):
|
||||
# fmt: off
|
||||
dest: StrictStr = Field(..., title="Destination of downloaded asset")
|
||||
url: Optional[StrictStr] = Field(None, title="URL of asset")
|
||||
checksum: Optional[str] = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
||||
description: StrictStr = Field("", title="Description of asset")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ProjectConfigAssetGit(BaseModel):
|
||||
# fmt: off
|
||||
git: ProjectConfigAssetGitItem = Field(..., title="Git repo information")
|
||||
checksum: Optional[str] = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
||||
description: Optional[StrictStr] = Field(None, title="Description of asset")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ProjectConfigCommand(BaseModel):
|
||||
# fmt: off
|
||||
name: StrictStr = Field(..., title="Name of command")
|
||||
help: Optional[StrictStr] = Field(None, title="Command description")
|
||||
script: List[StrictStr] = Field([], title="List of CLI commands to run, in order")
|
||||
deps: List[StrictStr] = Field([], title="File dependencies required by this command")
|
||||
outputs: List[StrictStr] = Field([], title="Outputs produced by this command")
|
||||
outputs_no_cache: List[StrictStr] = Field([], title="Outputs not tracked by DVC (DVC only)")
|
||||
no_skip: bool = Field(False, title="Never skip this command, even if nothing changed")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
title = "A single named command specified in a project config"
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
class ProjectConfigSchema(BaseModel):
|
||||
# fmt: off
|
||||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
||||
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
||||
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
||||
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
||||
title: Optional[str] = Field(None, title="Project title")
|
||||
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
title = "Schema for project configuration file"
|
||||
|
||||
|
||||
# Recommendations for init config workflows
|
||||
|
||||
|
||||
|
|
|
@@ -1,8 +1,8 @@
# cython: infer_types=True
# cython: profile=False
cimport cython
from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from libcpp.set cimport set
from murmurhash.mrmr cimport hash32, hash64

import srsly

@@ -20,9 +20,10 @@ cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
    try:
        out_hash[0] = key
        return True
    except:
    except:  # no-cython-lint
        return False


def get_string_id(key):
    """Get a string ID, handling the reserved symbols correctly. If the key is
    already an ID, return it.
@ -87,7 +88,6 @@ cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) e
|
|||
cdef int n_length_bytes
|
||||
cdef int i
|
||||
cdef Utf8Str* string = <Utf8Str*>mem.alloc(1, sizeof(Utf8Str))
|
||||
cdef uint32_t ulength = length
|
||||
if length < sizeof(string.s):
|
||||
string.s[0] = <unsigned char>length
|
||||
memcpy(&string.s[1], chars, length)
|
||||
|
|
|
@ -52,7 +52,7 @@ cdef struct TokenC:
|
|||
|
||||
int sent_start
|
||||
int ent_iob
|
||||
attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
|
||||
attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
|
||||
attr_t ent_kb_id
|
||||
hash_t ent_id
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ cdef enum symbol_t:
|
|||
ADV
|
||||
AUX
|
||||
CONJ
|
||||
CCONJ # U20
|
||||
CCONJ # U20
|
||||
DET
|
||||
INTJ
|
||||
NOUN
|
||||
|
@ -418,7 +418,7 @@ cdef enum symbol_t:
|
|||
ccomp
|
||||
complm
|
||||
conj
|
||||
cop # U20
|
||||
cop # U20
|
||||
csubj
|
||||
csubjpass
|
||||
dep
|
||||
|
@ -441,8 +441,8 @@ cdef enum symbol_t:
|
|||
num
|
||||
number
|
||||
oprd
|
||||
obj # U20
|
||||
obl # U20
|
||||
obj # U20
|
||||
obl # U20
|
||||
parataxis
|
||||
partmod
|
||||
pcomp
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# cython: optimize.unpack_method_calls=False
|
||||
# cython: profile=False
|
||||
IDS = {
|
||||
"": NIL,
|
||||
"IS_ALPHA": IS_ALPHA,
|
||||
|
@ -96,7 +97,7 @@ IDS = {
|
|||
"ADV": ADV,
|
||||
"AUX": AUX,
|
||||
"CONJ": CONJ,
|
||||
"CCONJ": CCONJ, # U20
|
||||
"CCONJ": CCONJ, # U20
|
||||
"DET": DET,
|
||||
"INTJ": INTJ,
|
||||
"NOUN": NOUN,
|
||||
|
@ -421,7 +422,7 @@ IDS = {
|
|||
"ccomp": ccomp,
|
||||
"complm": complm,
|
||||
"conj": conj,
|
||||
"cop": cop, # U20
|
||||
"cop": cop, # U20
|
||||
"csubj": csubj,
|
||||
"csubjpass": csubjpass,
|
||||
"dep": dep,
|
||||
|
@ -444,8 +445,8 @@ IDS = {
|
|||
"num": num,
|
||||
"number": number,
|
||||
"oprd": oprd,
|
||||
"obj": obj, # U20
|
||||
"obl": obl, # U20
|
||||
"obj": obj, # U20
|
||||
"obl": obl, # U20
|
||||
"parataxis": parataxis,
|
||||
"partmod": partmod,
|
||||
"pcomp": pcomp,
|
||||
|
|
|
@ -216,6 +216,11 @@ def test_dependency_matcher_pattern_validation(en_vocab):
|
|||
pattern2 = copy.deepcopy(pattern)
|
||||
pattern2[1]["RIGHT_ID"] = "fox"
|
||||
matcher.add("FOUNDED", [pattern2])
|
||||
# invalid key
|
||||
with pytest.warns(UserWarning):
|
||||
pattern2 = copy.deepcopy(pattern)
|
||||
pattern2[1]["FOO"] = "BAR"
|
||||
matcher.add("FOUNDED", [pattern2])
|
||||
|
||||
|
||||
def test_dependency_matcher_callback(en_vocab, doc):
|
||||
|
|
|
@ -52,7 +52,8 @@ TEST_PATTERNS = [
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pattern", [[{"XX": "y"}, {"LENGTH": "2"}, {"TEXT": {"IN": 5}}]]
|
||||
"pattern",
|
||||
[[{"XX": "y"}], [{"LENGTH": "2"}], [{"TEXT": {"IN": 5}}], [{"text": {"in": 6}}]],
|
||||
)
|
||||
def test_matcher_pattern_validation(en_vocab, pattern):
|
||||
matcher = Matcher(en_vocab, validate=True)
|
||||
|
|
|
@ -4,14 +4,15 @@ from pathlib import Path
|
|||
|
||||
def test_build_dependencies():
|
||||
# Check that library requirements are pinned exactly the same across different setup files.
|
||||
# TODO: correct checks for numpy rather than ignoring
|
||||
libs_ignore_requirements = [
|
||||
"numpy",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"mock",
|
||||
"flake8",
|
||||
"hypothesis",
|
||||
"pre-commit",
|
||||
"cython-lint",
|
||||
"black",
|
||||
"isort",
|
||||
"mypy",
|
||||
|
@ -22,6 +23,7 @@ def test_build_dependencies():
|
|||
]
|
||||
# ignore language-specific packages that shouldn't be installed by all
|
||||
libs_ignore_setup = [
|
||||
"numpy",
|
||||
"fugashi",
|
||||
"natto-py",
|
||||
"pythainlp",
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
import pytest
|
||||
from pydantic import StrictBool
|
||||
|
||||
try:
|
||||
from pydantic.v1 import StrictBool
|
||||
except ImportError:
|
||||
from pydantic import StrictBool # type: ignore
|
||||
|
||||
from thinc.api import ConfigValidationError
|
||||
|
||||
from spacy.lang.en import English
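The try/except import added here is the usual pydantic compatibility shim: pydantic v2 re-exports the old API under pydantic.v1, while an actual v1 install only offers the top-level module. The same pattern, sketched with illustrative names:

try:
    from pydantic.v1 import BaseModel, StrictStr  # pydantic v2 with the v1 shim
except ImportError:
    from pydantic import BaseModel, StrictStr  # type: ignore  # plain pydantic v1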
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
import pytest
|
||||
from pydantic import StrictInt, StrictStr
|
||||
|
||||
try:
|
||||
from pydantic.v1 import StrictInt, StrictStr
|
||||
except ImportError:
|
||||
from pydantic import StrictInt, StrictStr # type: ignore
|
||||
|
||||
from thinc.api import ConfigValidationError, Linear, Model
|
||||
|
||||
import spacy
|
||||
|
|
|
@ -1,31 +1,20 @@
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
import srsly
|
||||
from click import NoSuchOption
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from thinc.api import Config, ConfigValidationError
|
||||
from thinc.api import Config
|
||||
|
||||
import spacy
|
||||
from spacy import about
|
||||
from spacy import info as spacy_info
|
||||
from spacy.cli import info
|
||||
from spacy.cli._util import (
|
||||
download_file,
|
||||
is_subpath_of,
|
||||
load_project_config,
|
||||
parse_config_overrides,
|
||||
string_to_list,
|
||||
substitute_project_variables,
|
||||
upload_file,
|
||||
validate_project_commands,
|
||||
walk_directory,
|
||||
)
|
||||
from spacy.cli._util import parse_config_overrides, string_to_list, walk_directory
|
||||
from spacy.cli.apply import apply
|
||||
from spacy.cli.debug_data import (
|
||||
_compile_gold,
|
||||
|
@ -43,13 +32,11 @@ from spacy.cli.find_threshold import find_threshold
|
|||
from spacy.cli.init_config import RECOMMENDATIONS, fill_config, init_config
|
||||
from spacy.cli.init_pipeline import _init_labels
|
||||
from spacy.cli.package import _is_permitted_package_name, get_third_party_dependencies
|
||||
from spacy.cli.project.remote_storage import RemoteStorage
|
||||
from spacy.cli.project.run import _check_requirements
|
||||
from spacy.cli.validate import get_model_pkgs
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.nl import Dutch
|
||||
from spacy.language import Language
|
||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||
from spacy.schemas import RecommendationSchema
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.tokens.span import Span
|
||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||
|
@ -134,25 +121,6 @@ def test_issue7055():
|
|||
assert "model" in filled_cfg["components"]["ner"]
|
||||
|
||||
|
||||
@pytest.mark.issue(11235)
|
||||
def test_issue11235():
|
||||
"""
|
||||
Test that the cli handles interpolation in the directory names correctly when loading project config.
|
||||
"""
|
||||
lang_var = "en"
|
||||
variables = {"lang": lang_var}
|
||||
commands = [{"name": "x", "script": ["hello ${vars.lang}"]}]
|
||||
directories = ["cfg", "${vars.lang}_model"]
|
||||
project = {"commands": commands, "vars": variables, "directories": directories}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
# Check that the directories are interpolated and created correctly
|
||||
assert os.path.exists(d / "cfg")
|
||||
assert os.path.exists(d / f"{lang_var}_model")
|
||||
assert cfg["commands"][0]["script"][0] == f"hello {lang_var}"
|
||||
|
||||
|
||||
@pytest.mark.issue(12566)
|
||||
@pytest.mark.parametrize(
|
||||
"factory,output_file",
|
||||
|
@ -225,6 +193,9 @@ def test_cli_info():
|
|||
raw_data = info(tmp_dir, exclude=[""])
|
||||
assert raw_data["lang"] == "nl"
|
||||
assert raw_data["components"] == ["textcat"]
|
||||
raw_data = spacy_info(tmp_dir, exclude=[""])
|
||||
assert raw_data["lang"] == "nl"
|
||||
assert raw_data["components"] == ["textcat"]
|
||||
|
||||
|
||||
def test_cli_converters_conllu_to_docs():
|
||||
|
@ -443,136 +414,6 @@ def test_cli_converters_conll_ner_to_docs():
|
|||
assert ent.text in ["New York City", "London"]
|
||||
|
||||
|
||||
def test_project_config_validation_full():
|
||||
config = {
|
||||
"vars": {"some_var": 20},
|
||||
"directories": ["assets", "configs", "corpus", "scripts", "training"],
|
||||
"assets": [
|
||||
{
|
||||
"dest": "x",
|
||||
"extra": True,
|
||||
"url": "https://example.com",
|
||||
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
||||
},
|
||||
{
|
||||
"dest": "y",
|
||||
"git": {
|
||||
"repo": "https://github.com/example/repo",
|
||||
"branch": "develop",
|
||||
"path": "y",
|
||||
},
|
||||
},
|
||||
{
|
||||
"dest": "z",
|
||||
"extra": False,
|
||||
"url": "https://example.com",
|
||||
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
||||
},
|
||||
],
|
||||
"commands": [
|
||||
{
|
||||
"name": "train",
|
||||
"help": "Train a model",
|
||||
"script": ["python -m spacy train config.cfg -o training"],
|
||||
"deps": ["config.cfg", "corpus/training.spcy"],
|
||||
"outputs": ["training/model-best"],
|
||||
},
|
||||
{"name": "test", "script": ["pytest", "custom.py"], "no_skip": True},
|
||||
],
|
||||
"workflows": {"all": ["train", "test"], "train": ["train"]},
|
||||
}
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
assert not errors
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[
|
||||
{"commands": [{"name": "a"}, {"name": "a"}]},
|
||||
{"commands": [{"name": "a"}], "workflows": {"a": []}},
|
||||
{"commands": [{"name": "a"}], "workflows": {"b": ["c"]}},
|
||||
],
|
||||
)
|
||||
def test_project_config_validation1(config):
|
||||
with pytest.raises(SystemExit):
|
||||
validate_project_commands(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,n_errors",
|
||||
[
|
||||
({"commands": {"a": []}}, 1),
|
||||
({"commands": [{"help": "..."}]}, 1),
|
||||
({"commands": [{"name": "a", "extra": "b"}]}, 1),
|
||||
({"commands": [{"extra": "b"}]}, 2),
|
||||
({"commands": [{"name": "a", "deps": [123]}]}, 1),
|
||||
],
|
||||
)
|
||||
def test_project_config_validation2(config, n_errors):
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
assert len(errors) == n_errors
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"int_value",
|
||||
[10, pytest.param("10", marks=pytest.mark.xfail)],
|
||||
)
|
||||
def test_project_config_interpolation(int_value):
|
||||
variables = {"a": int_value, "b": {"c": "foo", "d": True}}
|
||||
commands = [
|
||||
{"name": "x", "script": ["hello ${vars.a} ${vars.b.c}"]},
|
||||
{"name": "y", "script": ["${vars.b.c} ${vars.b.d}"]},
|
||||
]
|
||||
project = {"commands": commands, "vars": variables}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
assert type(cfg) == dict
|
||||
assert type(cfg["commands"]) == list
|
||||
assert cfg["commands"][0]["script"][0] == "hello 10 foo"
|
||||
assert cfg["commands"][1]["script"][0] == "foo true"
|
||||
commands = [{"name": "x", "script": ["hello ${vars.a} ${vars.b.e}"]}]
|
||||
project = {"commands": commands, "vars": variables}
|
||||
with pytest.raises(ConfigValidationError):
|
||||
substitute_project_variables(project)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"greeting",
|
||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
||||
)
|
||||
def test_project_config_interpolation_override(greeting):
|
||||
variables = {"a": "world"}
|
||||
commands = [
|
||||
{"name": "x", "script": ["hello ${vars.a}"]},
|
||||
]
|
||||
overrides = {"vars.a": greeting}
|
||||
project = {"commands": commands, "vars": variables}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d, overrides=overrides)
|
||||
assert type(cfg) == dict
|
||||
assert type(cfg["commands"]) == list
|
||||
assert cfg["commands"][0]["script"][0] == f"hello {greeting}"
|
||||
|
||||
|
||||
def test_project_config_interpolation_env():
|
||||
variables = {"a": 10}
|
||||
env_var = "SPACY_TEST_FOO"
|
||||
env_vars = {"foo": env_var}
|
||||
commands = [{"name": "x", "script": ["hello ${vars.a} ${env.foo}"]}]
|
||||
project = {"commands": commands, "vars": variables, "env": env_vars}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
assert cfg["commands"][0]["script"][0] == "hello 10 "
|
||||
os.environ[env_var] = "123"
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
assert cfg["commands"][0]["script"][0] == "hello 10 123"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args,expected",
|
||||
[
|
||||
|
@ -782,21 +623,6 @@ def test_get_third_party_dependencies():
|
|||
get_third_party_dependencies(nlp.config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"parent,child,expected",
|
||||
[
|
||||
("/tmp", "/tmp", True),
|
||||
("/tmp", "/", False),
|
||||
("/tmp", "/tmp/subdir", True),
|
||||
("/tmp", "/tmpdir", False),
|
||||
("/tmp", "/tmp/subdir/..", True),
|
||||
("/tmp", "/tmp/..", False),
|
||||
],
|
||||
)
|
||||
def test_is_subpath_of(parent, child, expected):
|
||||
assert is_subpath_of(parent, child) == expected
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"factory_name,pipe_name",
|
||||
|
@ -1042,60 +868,6 @@ def test_applycli_user_data():
|
|||
assert result[0]._.ext == val
|
||||
|
||||
|
||||
def test_local_remote_storage():
|
||||
with make_tempdir() as d:
|
||||
filename = "a.txt"
|
||||
|
||||
content_hashes = ("aaaa", "cccc", "bbbb")
|
||||
for i, content_hash in enumerate(content_hashes):
|
||||
# make sure that each subsequent file has a later timestamp
|
||||
if i > 0:
|
||||
time.sleep(1)
|
||||
content = f"{content_hash} content"
|
||||
loc_file = d / "root" / filename
|
||||
if not loc_file.parent.exists():
|
||||
loc_file.parent.mkdir(parents=True)
|
||||
with loc_file.open(mode="w") as file_:
|
||||
file_.write(content)
|
||||
|
||||
# push first version to remote storage
|
||||
remote = RemoteStorage(d / "root", str(d / "remote"))
|
||||
remote.push(filename, "aaaa", content_hash)
|
||||
|
||||
# retrieve with full hashes
|
||||
loc_file.unlink()
|
||||
remote.pull(filename, command_hash="aaaa", content_hash=content_hash)
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
# retrieve with command hash
|
||||
loc_file.unlink()
|
||||
remote.pull(filename, command_hash="aaaa")
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
# retrieve with content hash
|
||||
loc_file.unlink()
|
||||
remote.pull(filename, content_hash=content_hash)
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
# retrieve with no hashes
|
||||
loc_file.unlink()
|
||||
remote.pull(filename)
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
|
||||
def test_local_remote_storage_pull_missing():
|
||||
# pulling from a non-existent remote pulls nothing gracefully
|
||||
with make_tempdir() as d:
|
||||
filename = "a.txt"
|
||||
remote = RemoteStorage(d / "root", str(d / "remote"))
|
||||
assert remote.pull(filename, command_hash="aaaa") is None
|
||||
assert remote.pull(filename) is None
|
||||
|
||||
|
||||
def test_cli_find_threshold(capsys):
|
||||
def make_examples(nlp: Language) -> List[Example]:
|
||||
docs: List[Example] = []
|
||||
|
@ -1206,63 +978,6 @@ def test_cli_find_threshold(capsys):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
||||
@pytest.mark.parametrize(
|
||||
"reqs,output",
|
||||
[
|
||||
[
|
||||
"""
|
||||
spacy
|
||||
|
||||
# comment
|
||||
|
||||
thinc""",
|
||||
(False, False),
|
||||
],
|
||||
[
|
||||
"""# comment
|
||||
--some-flag
|
||||
spacy""",
|
||||
(False, False),
|
||||
],
|
||||
[
|
||||
"""# comment
|
||||
--some-flag
|
||||
spacy; python_version >= '3.6'""",
|
||||
(False, False),
|
||||
],
|
||||
[
|
||||
"""# comment
|
||||
spacyunknowndoesnotexist12345""",
|
||||
(True, False),
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_project_check_requirements(reqs, output):
|
||||
import pkg_resources
|
||||
|
||||
# excessive guard against unlikely package name
|
||||
try:
|
||||
pkg_resources.require("spacyunknowndoesnotexist12345")
|
||||
except pkg_resources.DistributionNotFound:
|
||||
assert output == _check_requirements([req.strip() for req in reqs.split("\n")])
|
||||
|
||||
|
||||
def test_upload_download_local_file():
|
||||
with make_tempdir() as d1, make_tempdir() as d2:
|
||||
filename = "f.txt"
|
||||
content = "content"
|
||||
local_file = d1 / filename
|
||||
remote_file = d2 / filename
|
||||
with local_file.open(mode="w") as file_:
|
||||
file_.write(content)
|
||||
upload_file(local_file, remote_file)
|
||||
local_file.unlink()
|
||||
download_file(remote_file, local_file)
|
||||
with local_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
|
||||
def test_walk_directory():
|
||||
with make_tempdir() as d:
|
||||
files = [
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
@ -6,7 +7,7 @@ import srsly
|
|||
from typer.testing import CliRunner
|
||||
|
||||
from spacy.cli._util import app, get_git_version
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.tokens import Doc, DocBin, Span
|
||||
|
||||
from .util import make_tempdir, normalize_whitespace
|
||||
|
||||
|
@ -213,6 +214,9 @@ def test_project_clone(options):
|
|||
assert (out / "README.md").is_file()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 12), reason="Python 3.12+ not supported for remotes"
|
||||
)
|
||||
def test_project_push_pull(project_dir):
|
||||
proj = dict(SAMPLE_PROJECT)
|
||||
remote = "xyz"
|
||||
|
@ -233,3 +237,196 @@ def test_project_push_pull(project_dir):
|
|||
result = CliRunner().invoke(app, ["project", "pull", remote, str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert test_file.is_file()
|
||||
|
||||
|
||||
def test_find_function_valid():
|
||||
# example of architecture in main code base
|
||||
function = "spacy.TextCatBOW.v2"
|
||||
result = CliRunner().invoke(app, ["find-function", function, "-r", "architectures"])
|
||||
assert f"Found registered function '{function}'" in result.stdout
|
||||
assert "textcat.py" in result.stdout
|
||||
|
||||
result = CliRunner().invoke(app, ["find-function", function])
|
||||
assert f"Found registered function '{function}'" in result.stdout
|
||||
assert "textcat.py" in result.stdout
|
||||
|
||||
# example of architecture in spacy-legacy
|
||||
function = "spacy.TextCatBOW.v1"
|
||||
result = CliRunner().invoke(app, ["find-function", function])
|
||||
assert f"Found registered function '{function}'" in result.stdout
|
||||
assert "spacy_legacy" in result.stdout
|
||||
assert "textcat.py" in result.stdout
|
||||
|
||||
|
||||
def test_find_function_invalid():
|
||||
# invalid registry
|
||||
function = "spacy.TextCatBOW.v2"
|
||||
registry = "foobar"
|
||||
result = CliRunner().invoke(
|
||||
app, ["find-function", function, "--registry", registry]
|
||||
)
|
||||
assert f"Unknown function registry: '{registry}'" in result.stdout
|
||||
|
||||
# invalid function
|
||||
function = "spacy.TextCatBOW.v666"
|
||||
result = CliRunner().invoke(app, ["find-function", function])
|
||||
assert f"Couldn't find registered function: '{function}'" in result.stdout
|
||||
|
||||
|
||||
example_words_1 = ["I", "like", "cats"]
|
||||
example_words_2 = ["I", "like", "dogs"]
|
||||
example_lemmas_1 = ["I", "like", "cat"]
|
||||
example_lemmas_2 = ["I", "like", "dog"]
|
||||
example_tags = ["PRP", "VBP", "NNS"]
|
||||
example_morphs = [
|
||||
"Case=Nom|Number=Sing|Person=1|PronType=Prs",
|
||||
"Tense=Pres|VerbForm=Fin",
|
||||
"Number=Plur",
|
||||
]
|
||||
example_deps = ["nsubj", "ROOT", "dobj"]
|
||||
example_pos = ["PRON", "VERB", "NOUN"]
|
||||
example_ents = ["O", "O", "I-ANIMAL"]
|
||||
example_spans = [(2, 3, "ANIMAL")]
|
||||
|
||||
TRAIN_EXAMPLE_1 = dict(
|
||||
words=example_words_1,
|
||||
lemmas=example_lemmas_1,
|
||||
tags=example_tags,
|
||||
morphs=example_morphs,
|
||||
deps=example_deps,
|
||||
heads=[1, 1, 1],
|
||||
pos=example_pos,
|
||||
ents=example_ents,
|
||||
spans=example_spans,
|
||||
cats={"CAT": 1.0, "DOG": 0.0},
|
||||
)
|
||||
TRAIN_EXAMPLE_2 = dict(
|
||||
words=example_words_2,
|
||||
lemmas=example_lemmas_2,
|
||||
tags=example_tags,
|
||||
morphs=example_morphs,
|
||||
deps=example_deps,
|
||||
heads=[1, 1, 1],
|
||||
pos=example_pos,
|
||||
ents=example_ents,
|
||||
spans=example_spans,
|
||||
cats={"CAT": 0.0, "DOG": 1.0},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"component,examples",
|
||||
[
|
||||
("tagger", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
|
||||
("morphologizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
|
||||
("trainable_lemmatizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
|
||||
("parser", [TRAIN_EXAMPLE_1] * 30),
|
||||
("ner", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
|
||||
("spancat", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
|
||||
("textcat", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2]),
|
||||
],
|
||||
)
|
||||
def test_init_config_trainable(component, examples, en_vocab):
|
||||
if component == "textcat":
|
||||
train_docs = []
|
||||
for example in examples:
|
||||
doc = Doc(en_vocab, words=example["words"])
|
||||
doc.cats = example["cats"]
|
||||
train_docs.append(doc)
|
||||
elif component == "spancat":
|
||||
train_docs = []
|
||||
for example in examples:
|
||||
doc = Doc(en_vocab, words=example["words"])
|
||||
doc.spans["sc"] = [
|
||||
Span(doc, start, end, label) for start, end, label in example["spans"]
|
||||
]
|
||||
train_docs.append(doc)
|
||||
else:
|
||||
train_docs = []
|
||||
for example in examples:
|
||||
# cats, spans are not valid kwargs for instantiating a Doc
|
||||
example = {k: v for k, v in example.items() if k not in ("cats", "spans")}
|
||||
doc = Doc(en_vocab, **example)
|
||||
train_docs.append(doc)
|
||||
|
||||
with make_tempdir() as d_in:
|
||||
train_bin = DocBin(docs=train_docs)
|
||||
train_bin.to_disk(d_in / "train.spacy")
|
||||
dev_bin = DocBin(docs=train_docs)
|
||||
dev_bin.to_disk(d_in / "dev.spacy")
|
||||
init_config_result = CliRunner().invoke(
|
||||
app,
|
||||
[
|
||||
"init",
|
||||
"config",
|
||||
f"{d_in}/config.cfg",
|
||||
"--lang",
|
||||
"en",
|
||||
"--pipeline",
|
||||
component,
|
||||
],
|
||||
)
|
||||
assert init_config_result.exit_code == 0
|
||||
train_result = CliRunner().invoke(
|
||||
app,
|
||||
[
|
||||
"train",
|
||||
f"{d_in}/config.cfg",
|
||||
"--paths.train",
|
||||
f"{d_in}/train.spacy",
|
||||
"--paths.dev",
|
||||
f"{d_in}/dev.spacy",
|
||||
"--output",
|
||||
f"{d_in}/model",
|
||||
],
|
||||
)
|
||||
assert train_result.exit_code == 0
|
||||
assert Path(d_in / "model" / "model-last").exists()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"component,examples",
|
||||
[("tagger,parser,morphologizer", [TRAIN_EXAMPLE_1, TRAIN_EXAMPLE_2] * 15)],
|
||||
)
|
||||
def test_init_config_trainable_multiple(component, examples, en_vocab):
|
||||
train_docs = []
|
||||
for example in examples:
|
||||
example = {k: v for k, v in example.items() if k not in ("cats", "spans")}
|
||||
doc = Doc(en_vocab, **example)
|
||||
train_docs.append(doc)
|
||||
|
||||
with make_tempdir() as d_in:
|
||||
train_bin = DocBin(docs=train_docs)
|
||||
train_bin.to_disk(d_in / "train.spacy")
|
||||
dev_bin = DocBin(docs=train_docs)
|
||||
dev_bin.to_disk(d_in / "dev.spacy")
|
||||
init_config_result = CliRunner().invoke(
|
||||
app,
|
||||
[
|
||||
"init",
|
||||
"config",
|
||||
f"{d_in}/config.cfg",
|
||||
"--lang",
|
||||
"en",
|
||||
"--pipeline",
|
||||
component,
|
||||
],
|
||||
)
|
||||
assert init_config_result.exit_code == 0
|
||||
train_result = CliRunner().invoke(
|
||||
app,
|
||||
[
|
||||
"train",
|
||||
f"{d_in}/config.cfg",
|
||||
"--paths.train",
|
||||
f"{d_in}/train.spacy",
|
||||
"--paths.dev",
|
||||
f"{d_in}/dev.spacy",
|
||||
"--output",
|
||||
f"{d_in}/model",
|
||||
],
|
||||
)
|
||||
assert train_result.exit_code == 0
|
||||
assert Path(d_in / "model" / "model-last").exists()
|
||||
|
|
|
@ -113,7 +113,7 @@ def test_issue5838():
|
|||
doc = nlp(sample_text)
|
||||
doc.ents = [Span(doc, 7, 8, label="test")]
|
||||
html = displacy.render(doc, style="ent")
|
||||
found = html.count("</br>")
|
||||
found = html.count("<br>")
|
||||
assert found == 4
|
||||
|
||||
|
||||
|
@ -350,6 +350,78 @@ def test_displacy_render_wrapper(en_vocab):
|
|||
displacy.set_render_wrapper(lambda html: html)
|
||||
|
||||
|
||||
def test_displacy_render_manual_dep():
|
||||
"""Test displacy.render with manual data for dep style"""
|
||||
parsed_dep = {
|
||||
"words": [
|
||||
{"text": "This", "tag": "DT"},
|
||||
{"text": "is", "tag": "VBZ"},
|
||||
{"text": "a", "tag": "DT"},
|
||||
{"text": "sentence", "tag": "NN"},
|
||||
],
|
||||
"arcs": [
|
||||
{"start": 0, "end": 1, "label": "nsubj", "dir": "left"},
|
||||
{"start": 2, "end": 3, "label": "det", "dir": "left"},
|
||||
{"start": 1, "end": 3, "label": "attr", "dir": "right"},
|
||||
],
|
||||
"title": "Title",
|
||||
}
|
||||
html = displacy.render([parsed_dep], style="dep", manual=True)
|
||||
for word in parsed_dep["words"]:
|
||||
assert word["text"] in html
|
||||
assert word["tag"] in html
|
||||
|
||||
|
||||
def test_displacy_render_manual_ent():
|
||||
"""Test displacy.render with manual data for ent style"""
|
||||
parsed_ents = [
|
||||
{
|
||||
"text": "But Google is starting from behind.",
|
||||
"ents": [{"start": 4, "end": 10, "label": "ORG"}],
|
||||
},
|
||||
{
|
||||
"text": "But Google is starting from behind.",
|
||||
"ents": [{"start": -100, "end": 100, "label": "COMPANY"}],
|
||||
"title": "Title",
|
||||
},
|
||||
]
|
||||
|
||||
html = displacy.render(parsed_ents, style="ent", manual=True)
|
||||
for parsed_ent in parsed_ents:
|
||||
assert parsed_ent["ents"][0]["label"] in html
|
||||
if "title" in parsed_ent:
|
||||
assert parsed_ent["title"] in html
|
||||
|
||||
|
||||
def test_displacy_render_manual_span():
|
||||
"""Test displacy.render with manual data for span style"""
|
||||
parsed_spans = [
|
||||
{
|
||||
"text": "Welcome to the Bank of China.",
|
||||
"spans": [
|
||||
{"start_token": 3, "end_token": 6, "label": "ORG"},
|
||||
{"start_token": 5, "end_token": 6, "label": "GPE"},
|
||||
],
|
||||
"tokens": ["Welcome", "to", "the", "Bank", "of", "China", "."],
|
||||
},
|
||||
{
|
||||
"text": "Welcome to the Bank of China.",
|
||||
"spans": [
|
||||
{"start_token": 3, "end_token": 6, "label": "ORG"},
|
||||
{"start_token": 5, "end_token": 6, "label": "GPE"},
|
||||
],
|
||||
"tokens": ["Welcome", "to", "the", "Bank", "of", "China", "."],
|
||||
"title": "Title",
|
||||
},
|
||||
]
|
||||
|
||||
html = displacy.render(parsed_spans, style="span", manual=True)
|
||||
for parsed_span in parsed_spans:
|
||||
assert parsed_span["spans"][0]["label"] in html
|
||||
if "title" in parsed_span:
|
||||
assert parsed_span["title"] in html
|
||||
|
||||
|
||||
def test_displacy_options_case():
|
||||
ents = ["foo", "BAR"]
|
||||
colors = {"FOO": "red", "bar": "green"}
|
||||
|
@ -377,3 +449,22 @@ def test_displacy_manual_sorted_entities():
|
|||
|
||||
html = displacy.render(doc, style="ent", manual=True)
|
||||
assert html.find("FIRST") < html.find("SECOND")
|
||||
|
||||
|
||||
@pytest.mark.issue(12816)
|
||||
def test_issue12816(en_vocab) -> None:
|
||||
"""Test that displaCy's span visualizer escapes annotated HTML tags correctly."""
|
||||
# Create a doc containing an annotated word and an unannotated HTML tag
|
||||
doc = Doc(en_vocab, words=["test", "<TEST>"])
|
||||
doc.spans["sc"] = [Span(doc, 0, 1, label="test")]
|
||||
|
||||
# Verify that the HTML tag is escaped when unannotated
|
||||
html = displacy.render(doc, style="span")
|
||||
assert "<TEST>" in html
|
||||
|
||||
# Annotate the HTML tag
|
||||
doc.spans["sc"].append(Span(doc, 1, 2, label="test"))
|
||||
|
||||
# Verify that the HTML tag is still escaped
|
||||
html = displacy.render(doc, style="span")
|
||||
assert "<TEST>" in html
|
||||
|
|
|
@ -3,7 +3,12 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
try:
|
||||
from pydantic.v1 import ValidationError
|
||||
except ImportError:
|
||||
from pydantic import ValidationError # type: ignore
|
||||
|
||||
from thinc.api import (
|
||||
Config,
|
||||
ConfigValidationError,
|
||||
|
|
|
@ -31,24 +31,58 @@ cdef class Tokenizer:
|
|||
|
||||
cdef Doc _tokenize_affixes(self, str string, bint with_special_cases)
|
||||
cdef int _apply_special_cases(self, Doc doc) except -1
|
||||
cdef void _filter_special_spans(self, vector[SpanC] &original,
|
||||
vector[SpanC] &filtered, int doc_len) nogil
|
||||
cdef object _prepare_special_spans(self, Doc doc,
|
||||
vector[SpanC] &filtered)
|
||||
cdef int _retokenize_special_spans(self, Doc doc, TokenC* tokens,
|
||||
object span_data)
|
||||
cdef int _try_specials_and_cache(self, hash_t key, Doc tokens,
|
||||
int* has_special,
|
||||
bint with_special_cases) except -1
|
||||
cdef int _tokenize(self, Doc tokens, str span, hash_t key,
|
||||
int* has_special, bint with_special_cases) except -1
|
||||
cdef str _split_affixes(self, Pool mem, str string,
|
||||
vector[LexemeC*] *prefixes,
|
||||
vector[LexemeC*] *suffixes, int* has_special,
|
||||
bint with_special_cases)
|
||||
cdef int _attach_tokens(self, Doc tokens, str string,
|
||||
vector[LexemeC*] *prefixes,
|
||||
vector[LexemeC*] *suffixes, int* has_special,
|
||||
bint with_special_cases) except -1
|
||||
cdef int _save_cached(self, const TokenC* tokens, hash_t key,
|
||||
int* has_special, int n) except -1
|
||||
cdef void _filter_special_spans(
|
||||
self,
|
||||
vector[SpanC] &original,
|
||||
vector[SpanC] &filtered,
|
||||
int doc_len,
|
||||
) nogil
|
||||
cdef object _prepare_special_spans(
|
||||
self,
|
||||
Doc doc,
|
||||
vector[SpanC] &filtered,
|
||||
)
|
||||
cdef int _retokenize_special_spans(
|
||||
self,
|
||||
Doc doc,
|
||||
TokenC* tokens,
|
||||
object span_data,
|
||||
)
|
||||
cdef int _try_specials_and_cache(
|
||||
self,
|
||||
hash_t key,
|
||||
Doc tokens,
|
||||
int* has_special,
|
||||
bint with_special_cases,
|
||||
) except -1
|
||||
cdef int _tokenize(
|
||||
self,
|
||||
Doc tokens,
|
||||
str span,
|
||||
hash_t key,
|
||||
int* has_special,
|
||||
bint with_special_cases,
|
||||
) except -1
|
||||
cdef str _split_affixes(
|
||||
self,
|
||||
Pool mem,
|
||||
str string,
|
||||
vector[LexemeC*] *prefixes,
|
||||
vector[LexemeC*] *suffixes, int* has_special,
|
||||
bint with_special_cases,
|
||||
)
|
||||
cdef int _attach_tokens(
|
||||
self,
|
||||
Doc tokens,
|
||||
str string,
|
||||
vector[LexemeC*] *prefixes,
|
||||
vector[LexemeC*] *suffixes, int* has_special,
|
||||
bint with_special_cases,
|
||||
) except -1
|
||||
cdef int _save_cached(
|
||||
self,
|
||||
const TokenC* tokens,
|
||||
hash_t key,
|
||||
int* has_special,
|
||||
int n,
|
||||
) except -1
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: embedsignature=True, profile=True, binding=True
|
||||
# cython: embedsignature=True, binding=True
|
||||
cimport cython
|
||||
from cymem.cymem cimport Pool
|
||||
from cython.operator cimport dereference as deref
|
||||
|
@ -8,20 +8,18 @@ from libcpp.set cimport set as stdset
|
|||
from preshed.maps cimport PreshMap
|
||||
|
||||
import re
|
||||
import warnings
|
||||
|
||||
from .lexeme cimport EMPTY_LEXEME
|
||||
from .strings cimport hash_string
|
||||
from .tokens.doc cimport Doc
|
||||
|
||||
from . import util
|
||||
from .attrs import intify_attrs
|
||||
from .errors import Errors, Warnings
|
||||
from .errors import Errors
|
||||
from .scorer import Scorer
|
||||
from .symbols import NORM, ORTH
|
||||
from .tokens import Span
|
||||
from .training import validate_examples
|
||||
from .util import get_words_and_spaces, registry
|
||||
from .util import get_words_and_spaces
|
||||
|
||||
|
||||
cdef class Tokenizer:
|
||||
|
@ -324,7 +322,7 @@ cdef class Tokenizer:
|
|||
cdef int span_start
|
||||
cdef int span_end
|
||||
while i < doc.length:
|
||||
if not i in span_data:
|
||||
if i not in span_data:
|
||||
tokens[i + offset] = doc.c[i]
|
||||
i += 1
|
||||
else:
|
||||
|
@ -395,12 +393,15 @@ cdef class Tokenizer:
|
|||
self._save_cached(&tokens.c[orig_size], orig_key, has_special,
|
||||
tokens.length - orig_size)
|
||||
|
||||
cdef str _split_affixes(self, Pool mem, str string,
|
||||
vector[const LexemeC*] *prefixes,
|
||||
vector[const LexemeC*] *suffixes,
|
||||
int* has_special,
|
||||
bint with_special_cases):
|
||||
cdef size_t i
|
||||
cdef str _split_affixes(
|
||||
self,
|
||||
Pool mem,
|
||||
str string,
|
||||
vector[const LexemeC*] *prefixes,
|
||||
vector[const LexemeC*] *suffixes,
|
||||
int* has_special,
|
||||
bint with_special_cases
|
||||
):
|
||||
cdef str prefix
|
||||
cdef str suffix
|
||||
cdef str minus_pre
|
||||
|
@ -445,10 +446,6 @@ cdef class Tokenizer:
|
|||
vector[const LexemeC*] *suffixes,
|
||||
int* has_special,
|
||||
bint with_special_cases) except -1:
|
||||
cdef bint specials_hit = 0
|
||||
cdef bint cache_hit = 0
|
||||
cdef int split, end
|
||||
cdef const LexemeC* const* lexemes
|
||||
cdef const LexemeC* lexeme
|
||||
cdef str span
|
||||
cdef int i
|
||||
|
@ -458,9 +455,11 @@ cdef class Tokenizer:
|
|||
if string:
|
||||
if self._try_specials_and_cache(hash_string(string), tokens, has_special, with_special_cases):
|
||||
pass
|
||||
elif (self.token_match and self.token_match(string)) or \
|
||||
(self.url_match and \
|
||||
self.url_match(string)):
|
||||
elif (
|
||||
(self.token_match and self.token_match(string)) or
|
||||
(self.url_match and self.url_match(string))
|
||||
):
|
||||
|
||||
# We're always saying 'no' to spaces here -- the caller will
|
||||
# fix up the outermost one, with reference to the original.
|
||||
# See Issue #859
|
||||
|
@ -821,7 +820,7 @@ cdef class Tokenizer:
|
|||
self.infix_finditer = None
|
||||
self.token_match = None
|
||||
self.url_match = None
|
||||
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||
util.from_bytes(bytes_data, deserializers, exclude)
|
||||
if "prefix_search" in data and isinstance(data["prefix_search"], str):
|
||||
self.prefix_search = re.compile(data["prefix_search"]).search
|
||||
if "suffix_search" in data and isinstance(data["suffix_search"], str):
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# cython: infer_types=True, bounds_check=False, profile=True
|
||||
# cython: infer_types=True, bounds_check=False
|
||||
from cymem.cymem cimport Pool
|
||||
from libc.stdlib cimport free, malloc
|
||||
from libc.string cimport memcpy, memset
|
||||
from libc.string cimport memset
|
||||
|
||||
import numpy
|
||||
from thinc.api import get_array_module
|
||||
|
@ -10,7 +9,7 @@ from ..attrs cimport MORPH, NORM
|
|||
from ..lexeme cimport EMPTY_LEXEME, Lexeme
|
||||
from ..structs cimport LexemeC, TokenC
|
||||
from ..vocab cimport Vocab
|
||||
from .doc cimport Doc, set_children_from_heads, token_by_end, token_by_start
|
||||
from .doc cimport Doc, set_children_from_heads, token_by_start
|
||||
from .span cimport Span
|
||||
from .token cimport Token
|
||||
|
||||
|
@ -147,7 +146,7 @@ def _merge(Doc doc, merges):
|
|||
syntactic root of the span.
|
||||
RETURNS (Token): The first newly merged token.
|
||||
"""
|
||||
cdef int i, merge_index, start, end, token_index, current_span_index, current_offset, offset, span_index
|
||||
cdef int i, merge_index, start, token_index, current_span_index, current_offset, offset, span_index
|
||||
cdef Span span
|
||||
cdef const LexemeC* lex
|
||||
cdef TokenC* token
|
||||
|
@ -165,7 +164,6 @@ def _merge(Doc doc, merges):
|
|||
merges.sort(key=_get_start)
|
||||
for merge_index, (span, attributes) in enumerate(merges):
|
||||
start = span.start
|
||||
end = span.end
|
||||
spans.append(span)
|
||||
# House the new merged token where it starts
|
||||
token = &doc.c[start]
|
||||
|
@ -203,8 +201,9 @@ def _merge(Doc doc, merges):
|
|||
# for the merged region. To do this, we create a boolean array indicating
|
||||
# whether the row is to be deleted, then use numpy.delete
|
||||
if doc.tensor is not None and doc.tensor.size != 0:
|
||||
doc.tensor = _resize_tensor(doc.tensor,
|
||||
[(m[0].start, m[0].end) for m in merges])
|
||||
doc.tensor = _resize_tensor(
|
||||
doc.tensor, [(m[0].start, m[0].end) for m in merges]
|
||||
)
|
||||
# Memorize span roots and sets dependencies of the newly merged
|
||||
# tokens to the dependencies of their roots.
|
||||
span_roots = []
|
||||
|
@ -267,11 +266,11 @@ def _merge(Doc doc, merges):
|
|||
span_index += 1
|
||||
if span_index < len(spans) and i == spans[span_index].start:
|
||||
# First token in a span
|
||||
doc.c[i - offset] = doc.c[i] # move token to its place
|
||||
doc.c[i - offset] = doc.c[i] # move token to its place
|
||||
offset += (spans[span_index].end - spans[span_index].start) - 1
|
||||
in_span = True
|
||||
if not in_span:
|
||||
doc.c[i - offset] = doc.c[i] # move token to its place
|
||||
doc.c[i - offset] = doc.c[i] # move token to its place
|
||||
|
||||
for i in range(doc.length - offset, doc.length):
|
||||
memset(&doc.c[i], 0, sizeof(TokenC))
|
||||
|
@ -345,7 +344,11 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
|
|||
if to_process_tensor:
|
||||
xp = get_array_module(doc.tensor)
|
||||
if xp is numpy:
|
||||
doc.tensor = xp.append(doc.tensor, xp.zeros((nb_subtokens,doc.tensor.shape[1]), dtype="float32"), axis=0)
|
||||
doc.tensor = xp.append(
|
||||
doc.tensor,
|
||||
xp.zeros((nb_subtokens, doc.tensor.shape[1]), dtype="float32"),
|
||||
axis=0
|
||||
)
|
||||
else:
|
||||
shape = (doc.tensor.shape[0] + nb_subtokens, doc.tensor.shape[1])
|
||||
resized_array = xp.zeros(shape, dtype="float32")
|
||||
|
@ -367,7 +370,8 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
|
|||
token.norm = 0 # reset norm
|
||||
if to_process_tensor:
|
||||
# setting the tensors of the split tokens to array of zeros
|
||||
doc.tensor[token_index + i:token_index + i + 1] = xp.zeros((1,doc.tensor.shape[1]), dtype="float32")
|
||||
doc.tensor[token_index + i:token_index + i + 1] = \
|
||||
xp.zeros((1, doc.tensor.shape[1]), dtype="float32")
|
||||
# Update the character offset of the subtokens
|
||||
if i != 0:
|
||||
token.idx = orig_token.idx + idx_offset
|
||||
|
@ -455,7 +459,6 @@ def normalize_token_attrs(Vocab vocab, attrs):
|
|||
def set_token_attrs(Token py_token, attrs):
|
||||
cdef TokenC* token = py_token.c
|
||||
cdef const LexemeC* lex = token.lex
|
||||
cdef Doc doc = py_token.doc
|
||||
# Assign attributes
|
||||
for attr_name, attr_value in attrs.items():
|
||||
if attr_name == "_": # Set extension attributes
|
||||
|
|
|
@ -31,7 +31,7 @@ cdef int token_by_start(const TokenC* tokens, int length, int start_char) except
|
|||
cdef int token_by_end(const TokenC* tokens, int length, int end_char) except -2
|
||||
|
||||
|
||||
cdef int [:,:] _get_lca_matrix(Doc, int start, int end)
|
||||
cdef int [:, :] _get_lca_matrix(Doc, int start, int end)
|
||||
|
||||
|
||||
cdef class Doc:
|
||||
|
@ -61,7 +61,6 @@ cdef class Doc:
|
|||
cdef int length
|
||||
cdef int max_length
|
||||
|
||||
|
||||
cdef public object noun_chunks_iterator
|
||||
|
||||
cdef object __weakref__
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import (
|
|||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
|
@ -134,7 +135,12 @@ class Doc:
|
|||
def text(self) -> str: ...
|
||||
@property
|
||||
def text_with_ws(self) -> str: ...
|
||||
ents: Tuple[Span]
|
||||
# Ideally the getter would output Tuple[Span]
|
||||
# see https://github.com/python/mypy/issues/3004
|
||||
@property
|
||||
def ents(self) -> Sequence[Span]: ...
|
||||
@ents.setter
|
||||
def ents(self, value: Sequence[Span]) -> None: ...
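At runtime these stubs describe a setter that accepts any sequence of Spans and a getter that hands back a tuple; a small illustrative sketch:

from spacy.tokens import Doc, Span
from spacy.vocab import Vocab

doc = Doc(Vocab(), words=["Apple", "rises"])
doc.ents = [Span(doc, 0, 1, label="ORG")]  # setter accepts a list
assert isinstance(doc.ents, tuple)         # getter returns a tuple of Spans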
|
||||
def set_ents(
|
||||
self,
|
||||
entities: List[Span],
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, bounds_check=False, profile=True
|
||||
# cython: infer_types=True, bounds_check=False
|
||||
from typing import Set
|
||||
|
||||
cimport cython
|
||||
|
@ -43,14 +43,13 @@ from ..attrs cimport (
|
|||
attr_id_t,
|
||||
)
|
||||
from ..lexeme cimport EMPTY_LEXEME, Lexeme
|
||||
from ..typedefs cimport attr_t, flags_t
|
||||
from ..typedefs cimport attr_t
|
||||
from .token cimport Token
|
||||
|
||||
from .. import parts_of_speech, schemas, util
|
||||
from ..attrs import IDS, intify_attr
|
||||
from ..compat import copy_reg, pickle
|
||||
from ..compat import copy_reg
|
||||
from ..errors import Errors, Warnings
|
||||
from ..morphology import Morphology
|
||||
from ..util import get_words_and_spaces
|
||||
from ._retokenize import Retokenizer
|
||||
from .underscore import Underscore, get_ext_args
|
||||
|
@ -784,7 +783,7 @@ cdef class Doc:
|
|||
# TODO:
|
||||
# 1. Test basic data-driven ORTH gazetteer
|
||||
# 2. Test more nuanced date and currency regex
|
||||
cdef attr_t entity_type, kb_id, ent_id
|
||||
cdef attr_t kb_id, ent_id
|
||||
cdef int ent_start, ent_end
|
||||
ent_spans = []
|
||||
for ent_info in ents:
|
||||
|
@ -987,7 +986,6 @@ cdef class Doc:
|
|||
>>> np_array = doc.to_array([LOWER, POS, ENT_TYPE, IS_ALPHA])
|
||||
"""
|
||||
cdef int i, j
|
||||
cdef attr_id_t feature
|
||||
cdef np.ndarray[attr_t, ndim=2] output
|
||||
# Handle scalar/list inputs of strings/ints for py_attr_ids
|
||||
# See also #3064
|
||||
|
@ -999,8 +997,10 @@ cdef class Doc:
|
|||
py_attr_ids = [py_attr_ids]
|
||||
# Allow strings, e.g. 'lemma' or 'LEMMA'
|
||||
try:
|
||||
py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
||||
for id_ in py_attr_ids]
|
||||
py_attr_ids = [
|
||||
(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
|
||||
for id_ in py_attr_ids
|
||||
]
|
||||
except KeyError as msg:
|
||||
keys = [k for k in IDS.keys() if not k.startswith("FLAG")]
|
||||
raise KeyError(Errors.E983.format(dict="IDS", key=msg, keys=keys)) from None
|
||||
|
@ -1030,8 +1030,6 @@ cdef class Doc:
|
|||
DOCS: https://spacy.io/api/doc#count_by
|
||||
"""
|
||||
cdef int i
|
||||
cdef attr_t attr
|
||||
cdef size_t count
|
||||
|
||||
if counts is None:
|
||||
counts = Counter()
|
||||
|
@ -1093,7 +1091,6 @@ cdef class Doc:
|
|||
cdef int i, col
|
||||
cdef int32_t abs_head_index
|
||||
cdef attr_id_t attr_id
|
||||
cdef TokenC* tokens = self.c
|
||||
cdef int length = len(array)
|
||||
if length != len(self):
|
||||
raise ValueError(Errors.E971.format(array_length=length, doc_length=len(self)))
|
||||
|
@ -1225,7 +1222,7 @@ cdef class Doc:
|
|||
span.label,
|
||||
span.kb_id,
|
||||
span.id,
|
||||
span.text, # included as a check
|
||||
span.text, # included as a check
|
||||
))
|
||||
char_offset += len(doc.text)
|
||||
if len(doc) > 0 and ensure_whitespace and not doc[-1].is_space and not bool(doc[-1].whitespace_):
|
||||
|
@ -1508,7 +1505,6 @@ cdef class Doc:
|
|||
attributes are inherited from the syntactic root of the span.
|
||||
RETURNS (Token): The first newly merged token.
|
||||
"""
|
||||
cdef str tag, lemma, ent_type
|
||||
attr_len = len(attributes)
|
||||
span_len = len(spans)
|
||||
if not attr_len == span_len:
|
||||
|
@ -1624,7 +1620,6 @@ cdef class Doc:
|
|||
for token in char_span[1:]:
|
||||
token.is_sent_start = False
|
||||
|
||||
|
||||
for span_group in doc_json.get("spans", {}):
|
||||
spans = []
|
||||
for span in doc_json["spans"][span_group]:
|
||||
|
@ -1656,7 +1651,7 @@ cdef class Doc:
|
|||
start = token_by_char(self.c, self.length, token_data["start"])
|
||||
value = token_data["value"]
|
||||
self[start]._.set(token_attr, value)
|
||||
|
||||
|
||||
for span_attr in doc_json.get("underscore_span", {}):
|
||||
if not Span.has_extension(span_attr):
|
||||
Span.set_extension(span_attr)
|
||||
|
@ -1698,7 +1693,7 @@ cdef class Doc:
|
|||
token_data["dep"] = token.dep_
|
||||
token_data["head"] = token.head.i
|
||||
data["tokens"].append(token_data)
|
||||
|
||||
|
||||
if self.spans:
|
||||
data["spans"] = {}
|
||||
for span_group in self.spans:
|
||||
|
@ -1769,7 +1764,6 @@ cdef class Doc:
|
|||
output.fill(255)
|
||||
cdef int i, j, start_idx, end_idx
|
||||
cdef bytes byte_string
|
||||
cdef unsigned char utf8_char
|
||||
for i, byte_string in enumerate(byte_strings):
|
||||
j = 0
|
||||
start_idx = 0
|
||||
|
@ -1822,8 +1816,6 @@ cdef int token_by_char(const TokenC* tokens, int length, int char_idx) except -2
|
|||
|
||||
cdef int set_children_from_heads(TokenC* tokens, int start, int end) except -1:
|
||||
# note: end is exclusive
|
||||
cdef TokenC* head
|
||||
cdef TokenC* child
|
||||
cdef int i
|
||||
# Set number of left/right children to 0. We'll increment it in the loops.
|
||||
for i in range(start, end):
|
||||
|
@ -1923,7 +1915,7 @@ cdef int _get_tokens_lca(Token token_j, Token token_k):
|
|||
return -1
|
||||
|
||||
|
||||
cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
|
||||
cdef int [:, :] _get_lca_matrix(Doc doc, int start, int end):
|
||||
"""Given a doc and a start and end position defining a set of contiguous
|
||||
tokens within it, returns a matrix of Lowest Common Ancestors (LCA), where
|
||||
LCA[i, j] is the index of the lowest common ancestor among token i and j.
|
||||
|
@ -1936,7 +1928,7 @@ cdef int [:,:] _get_lca_matrix(Doc doc, int start, int end):
|
|||
RETURNS (int [:, :]): memoryview of numpy.array[ndim=2, dtype=numpy.int32],
|
||||
with shape (n, n), where n = len(doc).
|
||||
"""
|
||||
cdef int [:,:] lca_matrix
|
||||
cdef int [:, :] lca_matrix
|
||||
cdef int j, k
|
||||
n_tokens= end - start
|
||||
lca_mat = numpy.empty((n_tokens, n_tokens), dtype=numpy.int32)
|
||||
|
|
Some files were not shown because too many files have changed in this diff.