mirror of https://github.com/explosion/spaCy.git
synced 2025-04-22 01:51:58 +03:00

Merge branch 'master' into feature/candidate-generation-by-docs

# Conflicts:
#	spacy/kb/kb_in_memory.pyx
#	spacy/pipeline/entity_linker.py
#	spacy/tests/doc/test_span.py
#	spacy/tests/pipeline/test_entity_linker.py
#	spacy/tokens/span.pyx

This commit is contained in: commit 0a36f9d9e1
.github/azure-steps.yml (vendored): 63 changed lines

@@ -52,51 +52,56 @@ steps:
      python -W error -c "import spacy"
    displayName: "Test import"

  # - script: |
  #     python -m spacy download ca_core_news_sm
  #     python -m spacy download ca_core_news_md
  #     python -c "import spacy; nlp=spacy.load('ca_core_news_sm'); doc=nlp('test')"
  #   displayName: 'Test download CLI'
  #   condition: eq(variables['python_version'], '3.8')
  #
  # - script: |
  #     python -W error -c "import ca_core_news_sm; nlp = ca_core_news_sm.load(); doc=nlp('test')"
  #   displayName: 'Test no warnings on load (#11713)'
  #   condition: eq(variables['python_version'], '3.8')
  - script: |
      python -m spacy download ca_core_news_sm
      python -m spacy download ca_core_news_md
      python -c "import spacy; nlp=spacy.load('ca_core_news_sm'); doc=nlp('test')"
    displayName: 'Test download CLI'
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -W error -m spacy info ca_core_news_sm | grep -q download_url
    displayName: 'Test download_url in info CLI'
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -W error -c "import ca_core_news_sm; nlp = ca_core_news_sm.load(); doc=nlp('test')"
    displayName: 'Test no warnings on load (#11713)'
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -m spacy convert extra/example_data/ner_example_data/ner-token-per-line-conll2003.json .
    displayName: 'Test convert CLI'
    condition: eq(variables['python_version'], '3.8')
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -m spacy init config -p ner -l ca ner.cfg
      python -m spacy debug config ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy
    displayName: 'Test debug config CLI'
    condition: eq(variables['python_version'], '3.8')
    condition: eq(variables['python_version'], '3.9')

  - script: |
      # will have errors due to sparse data, check for summary in output
      python -m spacy debug data ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy | grep -q Summary
    displayName: 'Test debug data CLI'
    condition: eq(variables['python_version'], '3.8')
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -m spacy train ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy --training.max_steps 10 --gpu-id -1
    displayName: 'Test train CLI'
    condition: eq(variables['python_version'], '3.8')
    condition: eq(variables['python_version'], '3.9')

  # - script: |
  #     python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_sm'}; config.to_disk('ner_source_sm.cfg')"
  #     PYTHONWARNINGS="error,ignore::DeprecationWarning" python -m spacy assemble ner_source_sm.cfg output_dir
  #   displayName: 'Test assemble CLI'
  #   condition: eq(variables['python_version'], '3.8')
  #
  # - script: |
  #     python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_md'}; config.to_disk('ner_source_md.cfg')"
  #     python -m spacy assemble ner_source_md.cfg output_dir 2>&1 | grep -q W113
  #   displayName: 'Test assemble CLI vectors warning'
  #   condition: eq(variables['python_version'], '3.8')
  - script: |
      python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_sm'}; config.to_disk('ner_source_sm.cfg')"
      PYTHONWARNINGS="error,ignore::DeprecationWarning" python -m spacy assemble ner_source_sm.cfg output_dir
    displayName: 'Test assemble CLI'
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_md'}; config.to_disk('ner_source_md.cfg')"
      python -m spacy assemble ner_source_md.cfg output_dir 2>&1 | grep -q W113
    displayName: 'Test assemble CLI vectors warning'
    condition: eq(variables['python_version'], '3.9')

  - script: |
      python -m pip install -U -r requirements.txt

@@ -111,9 +116,3 @@ steps:
      python -m pytest --pyargs spacy
    displayName: "Run CPU tests with thinc-apple-ops"
    condition: and(startsWith(variables['imageName'], 'macos'), eq(variables['python.version'], '3.11'))

  - script: |
      python .github/validate_universe_json.py website/meta/universe.json
    displayName: 'Test website/meta/universe.json'
    condition: eq(variables['python_version'], '3.8')
.github/workflows/autoblack.yml (vendored): 45 changed lines (file removed)

@@ -1,45 +0,0 @@
# GitHub Action that uses Black to reformat all Python code and submits a PR
# in regular intervals. Inspired by: https://github.com/cclauss/autoblack

name: autoblack
on:
  workflow_dispatch: # allow manual trigger
  schedule:
    - cron: '0 8 * * 5' # every Friday at 8am UTC

jobs:
  autoblack:
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          ref: ${{ github.head_ref }}
      - uses: actions/setup-python@v4
      - run: pip install black
      - name: Auto-format code if needed
        run: black spacy
      # We can't run black --check here because that returns a non-zero exit
      # code and makes GitHub think the action failed
      - name: Check for modified files
        id: git-check
        run: echo modified=$(if git diff-index --quiet HEAD --; then echo "false"; else echo "true"; fi) >> $GITHUB_OUTPUT

      - name: Create Pull Request
        if: steps.git-check.outputs.modified == 'true'
        uses: peter-evans/create-pull-request@v4
        with:
          title: Auto-format code with black
          labels: meta
          commit-message: Auto-format code with black
          committer: GitHub <noreply@github.com>
          author: explosion-bot <explosion-bot@users.noreply.github.com>
          body: _This PR is auto-generated._
          branch: autoblack
          delete-branch: true
          draft: false
      - name: Check outputs
        if: steps.git-check.outputs.modified == 'true'
        run: |
          echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}"
          echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}"
.github/workflows/explosionbot.yml (vendored): 1 changed line

@@ -8,6 +8,7 @@ on:

jobs:
  explosion-bot:
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - name: Dump GitHub context
.github/workflows/issue-manager.yml (vendored): 1 changed line

@@ -13,6 +13,7 @@ on:

jobs:
  issue-manager:
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - uses: tiangolo/issue-manager@0.4.0
.github/workflows/lock.yml (vendored): 1 changed line

@@ -13,6 +13,7 @@ concurrency:

jobs:
  action:
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - uses: dessant/lock-threads@v4
.github/workflows/spacy_universe_alert.yml (vendored): 1 changed line

@@ -7,6 +7,7 @@ on:

jobs:
  build:
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest

    steps:
.github/workflows/tests.yml (vendored, new file): 173 lines

@@ -0,0 +1,173 @@
name: tests

on:
  push:
    branches-ignore:
      - "spacy.io"
      - "nightly.spacy.io"
      - "v2.spacy.io"
    paths-ignore:
      - "*.md"
      - "*.mdx"
      - "website/**"
      - ".github/workflows/**"
  pull_request:
    types: [opened, synchronize, reopened, edited]
    paths-ignore:
      - "*.md"
      - "*.mdx"
      - "website/**"

jobs:
  validate:
    name: Validate
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v3

      - name: Configure Python version
        uses: actions/setup-python@v4
        with:
          python-version: "3.7"
          architecture: x64

      - name: black
        run: |
          python -m pip install black -c requirements.txt
          python -m black spacy --check
      - name: flake8
        run: |
          python -m pip install flake8==5.0.4
          python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics
  tests:
    name: Test
    needs: Validate
    strategy:
      fail-fast: true
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python_version: ["3.11"]
        include:
          - os: ubuntu-20.04
            python_version: "3.6"
          - os: windows-latest
            python_version: "3.7"
          - os: macos-latest
            python_version: "3.8"
          - os: ubuntu-latest
            python_version: "3.9"
          - os: windows-latest
            python_version: "3.10"

    runs-on: ${{ matrix.os }}

    steps:
      - name: Check out repo
        uses: actions/checkout@v3

      - name: Configure Python version
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python_version }}
          architecture: x64

      - name: Install dependencies
        run: |
          python -m pip install -U build pip setuptools
          python -m pip install -U -r requirements.txt

      - name: Build sdist
        run: |
          python -m build --sdist

      - name: Run mypy
        run: |
          python -m mypy spacy
        if: matrix.python_version != '3.6'

      - name: Delete source directory and .egg-info
        run: |
          rm -rf spacy *.egg-info
        shell: bash

      - name: Uninstall all packages
        run: |
          python -m pip freeze
          python -m pip freeze --exclude pywin32 > installed.txt
          python -m pip uninstall -y -r installed.txt

      - name: Install from sdist
        run: |
          SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
          SPACY_NUM_BUILD_JOBS=2 python -m pip install dist/$SDIST
        shell: bash

      - name: Test import
        run: python -W error -c "import spacy"

      - name: "Test download CLI"
        run: |
          python -m spacy download ca_core_news_sm
          python -m spacy download ca_core_news_md
          python -c "import spacy; nlp=spacy.load('ca_core_news_sm'); doc=nlp('test')"
        if: matrix.python_version == '3.9'

      - name: "Test download_url in info CLI"
        run: |
          python -W error -m spacy info ca_core_news_sm | grep -q download_url
        if: matrix.python_version == '3.9'

      - name: "Test no warnings on load (#11713)"
        run: |
          python -W error -c "import ca_core_news_sm; nlp = ca_core_news_sm.load(); doc=nlp('test')"
        if: matrix.python_version == '3.9'

      - name: "Test convert CLI"
        run: |
          python -m spacy convert extra/example_data/ner_example_data/ner-token-per-line-conll2003.json .
        if: matrix.python_version == '3.9'

      - name: "Test debug config CLI"
        run: |
          python -m spacy init config -p ner -l ca ner.cfg
          python -m spacy debug config ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy
        if: matrix.python_version == '3.9'

      - name: "Test debug data CLI"
        run: |
          # will have errors due to sparse data, check for summary in output
          python -m spacy debug data ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy | grep -q Summary
        if: matrix.python_version == '3.9'

      - name: "Test train CLI"
        run: |
          python -m spacy train ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy --training.max_steps 10 --gpu-id -1
        if: matrix.python_version == '3.9'

      - name: "Test assemble CLI"
        run: |
          python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_sm'}; config.to_disk('ner_source_sm.cfg')"
          PYTHONWARNINGS="error,ignore::DeprecationWarning" python -m spacy assemble ner_source_sm.cfg output_dir
        if: matrix.python_version == '3.9'

      - name: "Test assemble CLI vectors warning"
        run: |
          python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_md'}; config.to_disk('ner_source_md.cfg')"
          python -m spacy assemble ner_source_md.cfg output_dir 2>&1 | grep -q W113
        if: matrix.python_version == '3.9'

      - name: "Install test requirements"
        run: |
          python -m pip install -U -r requirements.txt

      - name: "Run CPU tests"
        run: |
          python -m pytest --pyargs spacy -W error

      - name: "Run CPU tests with thinc-apple-ops"
        run: |
          python -m pip install 'spacy[apple]'
          python -m pytest --pyargs spacy
        if: startsWith(matrix.os, 'macos') && matrix.python_version == '3.11'
.github/workflows/universe_validation.yml (vendored, new file): 33 lines

@@ -0,0 +1,33 @@
name: universe validation

on:
  push:
    branches-ignore:
      - "spacy.io"
      - "nightly.spacy.io"
      - "v2.spacy.io"
    paths:
      - "website/meta/universe.json"
  pull_request:
    types: [opened, synchronize, reopened, edited]
    paths:
      - "website/meta/universe.json"

jobs:
  validate:
    name: Validate
    if: github.repository_owner == 'explosion'
    runs-on: ubuntu-latest
    steps:
      - name: Check out repo
        uses: actions/checkout@v3

      - name: Configure Python version
        uses: actions/setup-python@v4
        with:
          python-version: "3.7"
          architecture: x64

      - name: Validate website/meta/universe.json
        run: |
          python .github/validate_universe_json.py website/meta/universe.json
.gitignore (vendored): 10 changed lines

@@ -10,16 +10,6 @@ spacy/tests/package/setup.cfg
spacy/tests/package/pyproject.toml
spacy/tests/package/requirements.txt

# Website
website/.cache/
website/public/
website/node_modules
website/.npm
website/logs
*.log
npm-debug.log*
quickstart-training-generator.js

# Cython / C extensions
cythonize.json
spacy/*.html
@@ -173,6 +173,11 @@ formatting and [`flake8`](http://flake8.pycqa.org/en/latest/) for linting its
Python modules. If you've built spaCy from source, you'll already have both
tools installed.

As a general rule of thumb, we use f-strings for any formatting of strings.
One exception is calls to Python's `logging` functionality.
To avoid unnecessary string conversions in these cases, we use string formatting
templates with `%s` and `%d` etc.

**⚠️ Note that formatting and linting is currently only possible for Python
modules in `.py` files, not Cython modules in `.pyx` and `.pxd` files.**
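As a quick illustration of that logging convention (a hypothetical snippet, not part of this diff): ordinary strings use f-strings, while logging calls pass %-style templates so formatting is deferred to the logger.

    import logging

    logger = logging.getLogger(__name__)

    name = "ner"
    label = f"component: {name}"               # general string formatting: f-string
    logger.debug("Added component: %s", name)  # logging: %s template, formatted lazily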
@@ -16,7 +16,10 @@ production-ready [**training system**](https://spacy.io/usage/training) and easy
model packaging, deployment and workflow management. spaCy is commercial
open-source software, released under the [MIT license](https://github.com/explosion/spaCy/blob/master/LICENSE).

💫 **Version 3.4 out now!**
💥 **We'd love to hear more about your experience with spaCy!**
[Fill out our survey here.](https://form.typeform.com/to/aMel9q9f)

💫 **Version 3.5 out now!**
[Check out the release notes here.](https://github.com/explosion/spaCy/releases)

[](https://dev.azure.com/explosion-ai/public/_build?definitionId=8)
@@ -11,18 +11,28 @@ trigger:
    exclude:
      - "website/*"
      - "*.md"
      - "*.mdx"
      - ".github/workflows/*"
pr:
  paths:
    exclude:
      - "*.md"
      - "*.mdx"
      - "website/docs/*"
      - "website/src/*"
      - "website/meta/*.tsx"
      - "website/meta/*.mjs"
      - "website/meta/languages.json"
      - "website/meta/site.json"
      - "website/meta/sidebars.json"
      - "website/meta/type-annotations.json"
      - "website/pages/*"
      - ".github/workflows/*"

jobs:
  # Perform basic checks for most important errors (syntax etc.) Uses the config
  # defined in .flake8 and overwrites the selected codes.
  # Check formatting and linting. Perform basic checks for most important errors
  # (syntax etc.) Uses the config defined in setup.cfg and overwrites the
  # selected codes.
  - job: "Validate"
    pool:
      vmImage: "ubuntu-latest"

@@ -30,10 +40,17 @@ jobs:
      - task: UsePythonVersion@0
        inputs:
          versionSpec: "3.7"
      - script: |
          pip install black -c requirements.txt
          python -m black spacy --check
        displayName: "black"
      - script: |
          pip install flake8==5.0.4
          python -m flake8 spacy --count --select=E901,E999,F821,F822,F823,W605 --show-source --statistics
        displayName: "flake8"
      - script: |
          python .github/validate_universe_json.py website/meta/universe.json
        displayName: 'Validate website/meta/universe.json'

  - job: "Test"
    dependsOn: "Validate"
|
|||
numpy==1.19.2; python_version=='3.8' and platform_machine=='aarch64'
|
||||
numpy==1.19.3; python_version=='3.9'
|
||||
numpy==1.21.3; python_version=='3.10'
|
||||
numpy; python_version>='3.11'
|
||||
numpy==1.23.2; python_version=='3.11'
|
||||
numpy; python_version>='3.12'
|
||||
|
|
|
@ -5,7 +5,7 @@ requires = [
|
|||
"cymem>=2.0.2,<2.1.0",
|
||||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.1.0,<8.2.0",
|
||||
"thinc>=8.1.8,<8.2.0",
|
||||
"numpy>=1.15.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Our libraries
|
||||
spacy-legacy>=3.0.10,<3.1.0
|
||||
spacy-legacy>=3.0.11,<3.1.0
|
||||
spacy-loggers>=1.0.0,<2.0.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.1.0,<8.2.0
|
||||
thinc>=8.1.8,<8.2.0
|
||||
ml_datasets>=0.2.0,<0.3.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
wasabi>=0.9.1,<1.2.0
|
||||
|
@ -22,7 +22,7 @@ langcodes>=3.2.0,<4.0.0
|
|||
# Official Python utilities
|
||||
setuptools
|
||||
packaging>=20.0
|
||||
typing_extensions>=3.7.4.1,<4.2.0; python_version < "3.8"
|
||||
typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8"
|
||||
# Development dependencies
|
||||
pre-commit>=2.13.0
|
||||
cython>=0.25,<3.0
|
||||
|
@ -31,10 +31,10 @@ pytest-timeout>=1.3.0,<2.0.0
|
|||
mock>=2.0.0,<3.0.0
|
||||
flake8>=3.8.0,<6.0.0
|
||||
hypothesis>=3.27.0,<7.0.0
|
||||
mypy>=0.990,<0.1000; platform_machine != "aarch64" and python_version >= "3.7"
|
||||
mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7"
|
||||
types-dataclasses>=0.1.3; python_version < "3.7"
|
||||
types-mock>=0.1.1
|
||||
types-setuptools>=57.0.0
|
||||
types-requests
|
||||
types-setuptools>=57.0.0
|
||||
black>=22.0,<23.0
|
||||
black==22.3.0
|
||||
|
|
setup.cfg: 47 changed lines

@@ -22,6 +22,7 @@ classifiers =
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9
    Programming Language :: Python :: 3.10
    Programming Language :: Python :: 3.11
    Topic :: Scientific/Engineering
project_urls =
    Release notes = https://github.com/explosion/spaCy/releases

@@ -38,15 +39,15 @@ setup_requires =
    cymem>=2.0.2,<2.1.0
    preshed>=3.0.2,<3.1.0
    murmurhash>=0.28.0,<1.1.0
    thinc>=8.1.0,<8.2.0
    thinc>=8.1.8,<8.2.0
install_requires =
    # Our libraries
    spacy-legacy>=3.0.10,<3.1.0
    spacy-legacy>=3.0.11,<3.1.0
    spacy-loggers>=1.0.0,<2.0.0
    murmurhash>=0.28.0,<1.1.0
    cymem>=2.0.2,<2.1.0
    preshed>=3.0.2,<3.1.0
    thinc>=8.1.0,<8.2.0
    thinc>=8.1.8,<8.2.0
    wasabi>=0.9.1,<1.2.0
    srsly>=2.4.3,<3.0.0
    catalogue>=2.0.6,<2.1.0

@@ -62,7 +63,7 @@ install_requires =
    # Official Python utilities
    setuptools
    packaging>=20.0
    typing_extensions>=3.7.4,<4.2.0; python_version < "3.8"
    typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8"
    langcodes>=3.2.0,<4.0.0

[options.entry_points]

@@ -73,45 +74,45 @@ console_scripts =
lookups =
    spacy_lookups_data>=1.0.3,<1.1.0
transformers =
    spacy_transformers>=1.1.2,<1.2.0
    spacy_transformers>=1.1.2,<1.3.0
ray =
    spacy_ray>=0.1.0,<1.0.0
cuda =
    cupy>=5.0.0b4,<12.0.0
    cupy>=5.0.0b4,<13.0.0
cuda80 =
    cupy-cuda80>=5.0.0b4,<12.0.0
    cupy-cuda80>=5.0.0b4,<13.0.0
cuda90 =
    cupy-cuda90>=5.0.0b4,<12.0.0
    cupy-cuda90>=5.0.0b4,<13.0.0
cuda91 =
    cupy-cuda91>=5.0.0b4,<12.0.0
    cupy-cuda91>=5.0.0b4,<13.0.0
cuda92 =
    cupy-cuda92>=5.0.0b4,<12.0.0
    cupy-cuda92>=5.0.0b4,<13.0.0
cuda100 =
    cupy-cuda100>=5.0.0b4,<12.0.0
    cupy-cuda100>=5.0.0b4,<13.0.0
cuda101 =
    cupy-cuda101>=5.0.0b4,<12.0.0
    cupy-cuda101>=5.0.0b4,<13.0.0
cuda102 =
    cupy-cuda102>=5.0.0b4,<12.0.0
    cupy-cuda102>=5.0.0b4,<13.0.0
cuda110 =
    cupy-cuda110>=5.0.0b4,<12.0.0
    cupy-cuda110>=5.0.0b4,<13.0.0
cuda111 =
    cupy-cuda111>=5.0.0b4,<12.0.0
    cupy-cuda111>=5.0.0b4,<13.0.0
cuda112 =
    cupy-cuda112>=5.0.0b4,<12.0.0
    cupy-cuda112>=5.0.0b4,<13.0.0
cuda113 =
    cupy-cuda113>=5.0.0b4,<12.0.0
    cupy-cuda113>=5.0.0b4,<13.0.0
cuda114 =
    cupy-cuda114>=5.0.0b4,<12.0.0
    cupy-cuda114>=5.0.0b4,<13.0.0
cuda115 =
    cupy-cuda115>=5.0.0b4,<12.0.0
    cupy-cuda115>=5.0.0b4,<13.0.0
cuda116 =
    cupy-cuda116>=5.0.0b4,<12.0.0
    cupy-cuda116>=5.0.0b4,<13.0.0
cuda117 =
    cupy-cuda117>=5.0.0b4,<12.0.0
    cupy-cuda117>=5.0.0b4,<13.0.0
cuda11x =
    cupy-cuda11x>=11.0.0,<12.0.0
    cupy-cuda11x>=11.0.0,<13.0.0
cuda-autodetect =
    cupy-wheel>=11.0.0,<12.0.0
    cupy-wheel>=11.0.0,<13.0.0
apple =
    thinc-apple-ops>=0.1.0.dev0,<1.0.0
# Language tokenizers with external dependencies
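(Aside, not part of the diff: these extras are what installs such as `pip install 'spacy[cuda11x]'` or `pip install 'spacy[apple]'` resolve, so the relaxed cupy pins above take effect when users install spaCy with a CUDA or Apple extra.)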
@@ -4,6 +4,7 @@ from ._util import app, setup_cli  # noqa: F401

# These are the actual functions, NOT the wrapped CLI commands. The CLI commands
# are registered automatically and won't have to be imported here.
from .benchmark_speed import benchmark_speed_cli  # noqa: F401
from .download import download  # noqa: F401
from .info import info  # noqa: F401
from .package import package  # noqa: F401

@@ -16,6 +17,7 @@ from .debug_config import debug_config  # noqa: F401
from .debug_model import debug_model  # noqa: F401
from .debug_diff import debug_diff  # noqa: F401
from .evaluate import evaluate  # noqa: F401
from .apply import apply  # noqa: F401
from .convert import convert  # noqa: F401
from .init_pipeline import init_pipeline_cli  # noqa: F401
from .init_config import init_config, fill_config  # noqa: F401
|
|||
commands to check and validate your config files, training and evaluation data,
|
||||
and custom model implementations.
|
||||
"""
|
||||
BENCHMARK_HELP = """Commands for benchmarking pipelines."""
|
||||
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
||||
|
||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||
|
@ -54,12 +55,14 @@ Arg = typer.Argument
|
|||
Opt = typer.Option
|
||||
|
||||
app = typer.Typer(name=NAME, help=HELP)
|
||||
benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_help=True)
|
||||
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
|
||||
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
|
||||
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)
|
||||
|
||||
app.add_typer(project_cli)
|
||||
app.add_typer(debug_cli)
|
||||
app.add_typer(benchmark_cli)
|
||||
app.add_typer(init_cli)
|
||||
|
||||
|
||||
|
@ -87,9 +90,9 @@ def parse_config_overrides(
|
|||
cli_overrides = _parse_overrides(args, is_cli=True)
|
||||
if cli_overrides:
|
||||
keys = [k for k in cli_overrides if k not in env_overrides]
|
||||
logger.debug(f"Config overrides from CLI: {keys}")
|
||||
logger.debug("Config overrides from CLI: %s", keys)
|
||||
if env_overrides:
|
||||
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
|
||||
logger.debug("Config overrides from env variables: %s", list(env_overrides))
|
||||
return {**cli_overrides, **env_overrides}
|
||||
|
||||
|
||||
|
@ -582,6 +585,33 @@ def setup_gpu(use_gpu: int, silent=None) -> None:
|
|||
local_msg.info("To switch to GPU 0, use the option: --gpu-id 0")
|
||||
|
||||
|
||||
def walk_directory(path: Path, suffix: Optional[str] = None) -> List[Path]:
|
||||
"""Given a directory and a suffix, recursively find all files matching the suffix.
|
||||
Directories or files with names beginning with a . are ignored, but hidden flags on
|
||||
filesystems are not checked.
|
||||
When provided with a suffix `None`, there is no suffix-based filtering."""
|
||||
if not path.is_dir():
|
||||
return [path]
|
||||
paths = [path]
|
||||
locs = []
|
||||
seen = set()
|
||||
for path in paths:
|
||||
if str(path) in seen:
|
||||
continue
|
||||
seen.add(str(path))
|
||||
if path.parts[-1].startswith("."):
|
||||
continue
|
||||
elif path.is_dir():
|
||||
paths.extend(path.iterdir())
|
||||
elif suffix is not None and not path.parts[-1].endswith(suffix):
|
||||
continue
|
||||
else:
|
||||
locs.append(path)
|
||||
# It's good to sort these, in case the ordering messes up cache.
|
||||
locs.sort()
|
||||
return locs
|
||||
|
||||
|
||||
def _format_number(number: Union[int, float], ndigits: int = 2) -> str:
|
||||
"""Formats a number (float or int) rounding to `ndigits`, without truncating trailing 0s,
|
||||
as happens with `round(number, ndigits)`"""
|
||||
|
|
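A minimal usage sketch for the relocated helper (paths are illustrative; per the hunk above, `walk_directory` now lives in the private module spacy/cli/_util.py):

    from pathlib import Path
    from spacy.cli._util import walk_directory

    # every file under ./corpus, skipping names that start with "."
    all_files = walk_directory(Path("corpus"))
    # only files whose names end in ".jsonl"
    jsonl_files = walk_directory(Path("corpus"), suffix=".jsonl")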
spacy/cli/apply.py (new file): 143 lines

@@ -0,0 +1,143 @@
import tqdm
import srsly

from itertools import chain
from pathlib import Path
from typing import Optional, List, Iterable, cast, Union

from wasabi import msg

from ._util import app, Arg, Opt, setup_gpu, import_code, walk_directory

from ..tokens import Doc, DocBin
from ..vocab import Vocab
from ..util import ensure_path, load_model


path_help = """Location of the documents to predict on.
Can be a single file in .spacy format or a .jsonl file.
Files with other extensions are treated as single plain text documents.
If a directory is provided it is traversed recursively to grab
all files to be processed.
The files can be a mixture of .spacy, .jsonl and text files.
If .jsonl is provided the specified field is going
to be grabbed ("text" by default)."""

out_help = "Path to save the resulting .spacy file"
code_help = (
    "Path to Python file with additional "
    "code (registered functions) to be imported"
)
gold_help = "Use gold preprocessing provided in the .spacy files"
force_msg = (
    "The provided output file already exists. "
    "To force overwriting the output file, set the --force or -F flag."
)


DocOrStrStream = Union[Iterable[str], Iterable[Doc]]


def _stream_docbin(path: Path, vocab: Vocab) -> Iterable[Doc]:
    """
    Stream Doc objects from DocBin.
    """
    docbin = DocBin().from_disk(path)
    for doc in docbin.get_docs(vocab):
        yield doc


def _stream_jsonl(path: Path, field: str) -> Iterable[str]:
    """
    Stream "text" field from JSONL. If the field "text" is
    not found it raises error.
    """
    for entry in srsly.read_jsonl(path):
        if field not in entry:
            msg.fail(f"{path} does not contain the required '{field}' field.", exits=1)
        else:
            yield entry[field]


def _stream_texts(paths: Iterable[Path]) -> Iterable[str]:
    """
    Yields strings from text files in paths.
    """
    for path in paths:
        with open(path, "r") as fin:
            text = fin.read()
            yield text


@app.command("apply")
def apply_cli(
    # fmt: off
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(..., help=path_help, exists=True),
    output_file: Path = Arg(..., help=out_help, dir_okay=False),
    code_path: Optional[Path] = Opt(None, "--code", "-c", help=code_help),
    text_key: str = Opt("text", "--text-key", "-tk", help="Key containing text string for JSONL"),
    force_overwrite: bool = Opt(False, "--force", "-F", help="Force overwriting the output file"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU."),
    batch_size: int = Opt(1, "--batch-size", "-b", help="Batch size."),
    n_process: int = Opt(1, "--n-process", "-n", help="number of processors to use.")
):
    """
    Apply a trained pipeline to documents to get predictions.
    Expects a loadable spaCy pipeline and path to the data, which
    can be a directory or a file.
    The data files can be provided in multiple formats:
    1. .spacy files
    2. .jsonl files with a specified "field" to read the text from.
    3. Files with any other extension are assumed to be containing
       a single document.
    DOCS: https://spacy.io/api/cli#apply
    """
    data_path = ensure_path(data_path)
    output_file = ensure_path(output_file)
    code_path = ensure_path(code_path)
    if output_file.exists() and not force_overwrite:
        msg.fail(force_msg, exits=1)
    if not data_path.exists():
        msg.fail(f"Couldn't find data path: {data_path}", exits=1)
    import_code(code_path)
    setup_gpu(use_gpu)
    apply(data_path, output_file, model, text_key, batch_size, n_process)


def apply(
    data_path: Path,
    output_file: Path,
    model: str,
    json_field: str,
    batch_size: int,
    n_process: int,
):
    docbin = DocBin(store_user_data=True)
    paths = walk_directory(data_path)
    if len(paths) == 0:
        docbin.to_disk(output_file)
        msg.warn(
            "Did not find data to process,"
            f" {data_path} seems to be an empty directory."
        )
        return
    nlp = load_model(model)
    msg.good(f"Loaded model {model}")
    vocab = nlp.vocab
    streams: List[DocOrStrStream] = []
    text_files = []
    for path in paths:
        if path.suffix == ".spacy":
            streams.append(_stream_docbin(path, vocab))
        elif path.suffix == ".jsonl":
            streams.append(_stream_jsonl(path, json_field))
        else:
            text_files.append(path)
    if len(text_files) > 0:
        streams.append(_stream_texts(text_files))
    datagen = cast(DocOrStrStream, chain(*streams))
    for doc in tqdm.tqdm(nlp.pipe(datagen, batch_size=batch_size, n_process=n_process)):
        docbin.add(doc)
    if output_file.suffix == "":
        output_file = output_file.with_suffix(".spacy")
    docbin.to_disk(output_file)
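A typical invocation of the new command might look like `python -m spacy apply en_core_web_sm ./data predictions.spacy --text-key text --batch-size 32` (model name and paths are illustrative); the predictions are written as a single DocBin to the given output file.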
spacy/cli/benchmark_speed.py (new file): 174 lines

@@ -0,0 +1,174 @@
from typing import Iterable, List, Optional
import random
from itertools import islice
import numpy
from pathlib import Path
import time
from tqdm import tqdm
import typer
from wasabi import msg

from .. import util
from ..language import Language
from ..tokens import Doc
from ..training import Corpus
from ._util import Arg, Opt, benchmark_cli, setup_gpu


@benchmark_cli.command(
    "speed",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def benchmark_speed_cli(
    # fmt: off
    ctx: typer.Context,
    model: str = Arg(..., help="Model name or path"),
    data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
    batch_size: Optional[int] = Opt(None, "--batch-size", "-b", min=1, help="Override the pipeline batch size"),
    no_shuffle: bool = Opt(False, "--no-shuffle", help="Do not shuffle benchmark data"),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    n_batches: int = Opt(50, "--batches", help="Minimum number of batches to benchmark", min=30,),
    warmup_epochs: int = Opt(3, "--warmup", "-w", min=0, help="Number of iterations over the data for warmup"),
    # fmt: on
):
    """
    Benchmark a pipeline. Expects a loadable spaCy pipeline and benchmark
    data in the binary .spacy format.
    """
    setup_gpu(use_gpu=use_gpu, silent=False)

    nlp = util.load_model(model)
    batch_size = batch_size if batch_size is not None else nlp.batch_size
    corpus = Corpus(data_path)
    docs = [eg.predicted for eg in corpus(nlp)]

    if len(docs) == 0:
        msg.fail("Cannot benchmark speed using an empty corpus.", exits=1)

    print(f"Warming up for {warmup_epochs} epochs...")
    warmup(nlp, docs, warmup_epochs, batch_size)

    print()
    print(f"Benchmarking {n_batches} batches...")
    wps = benchmark(nlp, docs, n_batches, batch_size, not no_shuffle)

    print()
    print_outliers(wps)
    print_mean_with_ci(wps)


# Lowercased, behaves as a context manager function.
class time_context:
    """Register the running time of a context."""

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.elapsed = time.perf_counter() - self.start


class Quartiles:
    """Calculate the q1, q2, q3 quartiles and the inter-quartile range (iqr)
    of a sample."""

    q1: float
    q2: float
    q3: float
    iqr: float

    def __init__(self, sample: numpy.ndarray) -> None:
        self.q1 = numpy.quantile(sample, 0.25)
        self.q2 = numpy.quantile(sample, 0.5)
        self.q3 = numpy.quantile(sample, 0.75)
        self.iqr = self.q3 - self.q1


def annotate(
    nlp: Language, docs: List[Doc], batch_size: Optional[int]
) -> numpy.ndarray:
    docs = nlp.pipe(tqdm(docs, unit="doc"), batch_size=batch_size)
    wps = []
    while True:
        with time_context() as elapsed:
            batch_docs = list(
                islice(docs, batch_size if batch_size else nlp.batch_size)
            )
        if len(batch_docs) == 0:
            break
        n_tokens = count_tokens(batch_docs)
        wps.append(n_tokens / elapsed.elapsed)

    return numpy.array(wps)


def benchmark(
    nlp: Language,
    docs: List[Doc],
    n_batches: int,
    batch_size: int,
    shuffle: bool,
) -> numpy.ndarray:
    if shuffle:
        bench_docs = [
            nlp.make_doc(random.choice(docs).text)
            for _ in range(n_batches * batch_size)
        ]
    else:
        bench_docs = [
            nlp.make_doc(docs[i % len(docs)].text)
            for i in range(n_batches * batch_size)
        ]

    return annotate(nlp, bench_docs, batch_size)


def bootstrap(x, statistic=numpy.mean, iterations=10000) -> numpy.ndarray:
    """Apply a statistic to repeated random samples of an array."""
    return numpy.fromiter(
        (
            statistic(numpy.random.choice(x, len(x), replace=True))
            for _ in range(iterations)
        ),
        numpy.float64,
    )


def count_tokens(docs: Iterable[Doc]) -> int:
    return sum(len(doc) for doc in docs)


def print_mean_with_ci(sample: numpy.ndarray):
    mean = numpy.mean(sample)
    bootstrap_means = bootstrap(sample)
    bootstrap_means.sort()

    # 95% confidence interval
    low = bootstrap_means[int(len(bootstrap_means) * 0.025)]
    high = bootstrap_means[int(len(bootstrap_means) * 0.975)]

    print(f"Mean: {mean:.1f} words/s (95% CI: {low-mean:.1f} +{high-mean:.1f})")


def print_outliers(sample: numpy.ndarray):
    quartiles = Quartiles(sample)

    n_outliers = numpy.sum(
        (sample < (quartiles.q1 - 1.5 * quartiles.iqr))
        | (sample > (quartiles.q3 + 1.5 * quartiles.iqr))
    )
    n_extreme_outliers = numpy.sum(
        (sample < (quartiles.q1 - 3.0 * quartiles.iqr))
        | (sample > (quartiles.q3 + 3.0 * quartiles.iqr))
    )
    print(
        f"Outliers: {(100 * n_outliers) / len(sample):.1f}%, extreme outliers: {(100 * n_extreme_outliers) / len(sample)}%"
    )


def warmup(
    nlp: Language, docs: List[Doc], warmup_epochs: int, batch_size: Optional[int]
) -> numpy.ndarray:
    docs = warmup_epochs * docs
    return annotate(nlp, docs, batch_size)
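The accompanying benchmark command can be run as, for example, `python -m spacy benchmark speed en_core_web_sm dev.spacy --batches 50 --warmup 3` (pipeline name and data path are illustrative); it reports mean words per second with a bootstrapped 95% confidence interval plus an outlier percentage, as implemented above.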
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, Mapping, Optional, Any, List, Union
from typing import Callable, Iterable, Mapping, Optional, Any, Union
from enum import Enum
from pathlib import Path
from wasabi import Printer

@@ -7,7 +7,7 @@ import re
import sys
import itertools

from ._util import app, Arg, Opt
from ._util import app, Arg, Opt, walk_directory
from ..training import docs_to_json
from ..tokens import Doc, DocBin
from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs

@@ -28,6 +28,8 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
    "json": json_to_docs,
}

AUTO = "auto"


# File types that can be written to stdout
FILE_TYPES_STDOUT = ("json",)

@@ -49,7 +51,7 @@ def convert_cli(
    model: Optional[str] = Opt(None, "--model", "--base", "-b", help="Trained spaCy pipeline for sentence segmentation to use as base (for --seg-sents)"),
    morphology: bool = Opt(False, "--morphology", "-m", help="Enable appending morphology to tags"),
    merge_subtokens: bool = Opt(False, "--merge-subtokens", "-T", help="Merge CoNLL-U subtokens"),
    converter: str = Opt("auto", "--converter", "-c", help=f"Converter: {tuple(CONVERTERS.keys())}"),
    converter: str = Opt(AUTO, "--converter", "-c", help=f"Converter: {tuple(CONVERTERS.keys())}"),
    ner_map: Optional[Path] = Opt(None, "--ner-map", "-nm", help="NER tag mapping (as JSON-encoded dict of entity types)", exists=True),
    lang: Optional[str] = Opt(None, "--lang", "-l", help="Language (if tokenizer required)"),
    concatenate: bool = Opt(None, "--concatenate", "-C", help="Concatenate output to a single file"),

@@ -70,8 +72,8 @@ def convert_cli(
    output_dir: Union[str, Path] = "-" if output_dir == Path("-") else output_dir
    silent = output_dir == "-"
    msg = Printer(no_print=silent)
    verify_cli_args(msg, input_path, output_dir, file_type.value, converter, ner_map)
    converter = _get_converter(msg, converter, input_path)
    verify_cli_args(msg, input_path, output_dir, file_type.value, converter, ner_map)
    convert(
        input_path,
        output_dir,

@@ -100,7 +102,7 @@ def convert(
    model: Optional[str] = None,
    morphology: bool = False,
    merge_subtokens: bool = False,
    converter: str = "auto",
    converter: str,
    ner_map: Optional[Path] = None,
    lang: Optional[str] = None,
    concatenate: bool = False,

@@ -189,33 +191,6 @@ def autodetect_ner_format(input_data: str) -> Optional[str]:
    return None


def walk_directory(path: Path, converter: str) -> List[Path]:
    if not path.is_dir():
        return [path]
    paths = [path]
    locs = []
    seen = set()
    for path in paths:
        if str(path) in seen:
            continue
        seen.add(str(path))
        if path.parts[-1].startswith("."):
            continue
        elif path.is_dir():
            paths.extend(path.iterdir())
        elif converter == "json" and not path.parts[-1].endswith("json"):
            continue
        elif converter == "conll" and not path.parts[-1].endswith("conll"):
            continue
        elif converter == "iob" and not path.parts[-1].endswith("iob"):
            continue
        else:
            locs.append(path)
    # It's good to sort these, in case the ordering messes up cache.
    locs.sort()
    return locs


def verify_cli_args(
    msg: Printer,
    input_path: Path,

@@ -239,18 +214,22 @@ def verify_cli_args(
        input_locs = walk_directory(input_path, converter)
        if len(input_locs) == 0:
            msg.fail("No input files in directory", input_path, exits=1)
        file_types = list(set([loc.suffix[1:] for loc in input_locs]))
        if converter == "auto" and len(file_types) >= 2:
            file_types_str = ",".join(file_types)
            msg.fail("All input files must be same type", file_types_str, exits=1)
    if converter != "auto" and converter not in CONVERTERS:
    if converter not in CONVERTERS:
        msg.fail(f"Can't find converter for {converter}", exits=1)


def _get_converter(msg, converter, input_path: Path):
    if input_path.is_dir():
        input_path = walk_directory(input_path, converter)[0]
        if converter == "auto":
        if converter == AUTO:
            input_locs = walk_directory(input_path, suffix=None)
            file_types = list(set([loc.suffix[1:] for loc in input_locs]))
            if len(file_types) >= 2:
                file_types_str = ",".join(file_types)
                msg.fail("All input files must be same type", file_types_str, exits=1)
            input_path = input_locs[0]
        else:
            input_path = walk_directory(input_path, suffix=converter)[0]
    if converter == AUTO:
        converter = input_path.suffix[1:]
    if converter == "ner" or converter == "iob":
        with input_path.open(encoding="utf8") as file_:
@@ -7,6 +7,7 @@ import srsly
from wasabi import Printer, MESSAGES, msg
import typer
import math
import numpy

from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli, _format_number

@@ -17,6 +18,7 @@ from ..pipeline import TrainablePipe
from ..pipeline._parser_internals import nonproj
from ..pipeline._parser_internals.nonproj import DELIMITER
from ..pipeline import Morphologizer, SpanCategorizer
from ..pipeline._edit_tree_internals.edit_trees import EditTrees
from ..morphology import Morphology
from ..language import Language
from ..util import registry, resolve_dot_names

@@ -520,9 +522,13 @@ def debug_data(

    if "tagger" in factory_names:
        msg.divider("Part-of-speech Tagging")
        label_list = [label for label in gold_train_data["tags"]]
        model_labels = _get_labels_from_model(nlp, "tagger")
        label_list, counts = zip(*gold_train_data["tags"].items())
        msg.info(f"{len(label_list)} label(s) in train data")
        p = numpy.array(counts)
        p = p / p.sum()
        norm_entropy = (-p * numpy.log2(p)).sum() / numpy.log2(len(label_list))
        msg.info(f"{norm_entropy} is the normalised label entropy")
        model_labels = _get_labels_from_model(nlp, "tagger")
        labels = set(label_list)
        missing_labels = model_labels - labels
        if missing_labels:

@@ -671,6 +677,59 @@ def debug_data(
            f"Found {gold_train_data['n_cycles']} projectivized train sentence(s) with cycles"
        )

    if "trainable_lemmatizer" in factory_names:
        msg.divider("Trainable Lemmatizer")
        trees_train: Set[str] = gold_train_data["lemmatizer_trees"]
        trees_dev: Set[str] = gold_dev_data["lemmatizer_trees"]
        # This is necessary context when someone is attempting to interpret whether the
        # number of trees exclusively in the dev set is meaningful.
        msg.info(f"{len(trees_train)} lemmatizer trees generated from training data")
        msg.info(f"{len(trees_dev)} lemmatizer trees generated from dev data")
        dev_not_train = trees_dev - trees_train

        if len(dev_not_train) != 0:
            pct = len(dev_not_train) / len(trees_dev)
            msg.info(
                f"{len(dev_not_train)} lemmatizer trees ({pct*100:.1f}% of dev trees)"
                " were found exclusively in the dev data."
            )
        else:
            # Would we ever expect this case? It seems like it would be pretty rare,
            # and we might actually want a warning?
            msg.info("All trees in dev data present in training data.")

        if gold_train_data["n_low_cardinality_lemmas"] > 0:
            n = gold_train_data["n_low_cardinality_lemmas"]
            msg.warn(f"{n} training docs with 0 or 1 unique lemmas.")

        if gold_dev_data["n_low_cardinality_lemmas"] > 0:
            n = gold_dev_data["n_low_cardinality_lemmas"]
            msg.warn(f"{n} dev docs with 0 or 1 unique lemmas.")

        if gold_train_data["no_lemma_annotations"] > 0:
            n = gold_train_data["no_lemma_annotations"]
            msg.warn(f"{n} training docs with no lemma annotations.")
        else:
            msg.good("All training docs have lemma annotations.")

        if gold_dev_data["no_lemma_annotations"] > 0:
            n = gold_dev_data["no_lemma_annotations"]
            msg.warn(f"{n} dev docs with no lemma annotations.")
        else:
            msg.good("All dev docs have lemma annotations.")

        if gold_train_data["partial_lemma_annotations"] > 0:
            n = gold_train_data["partial_lemma_annotations"]
            msg.info(f"{n} training docs with partial lemma annotations.")
        else:
            msg.good("All training docs have complete lemma annotations.")

        if gold_dev_data["partial_lemma_annotations"] > 0:
            n = gold_dev_data["partial_lemma_annotations"]
            msg.info(f"{n} dev docs with partial lemma annotations.")
        else:
            msg.good("All dev docs have complete lemma annotations.")

    msg.divider("Summary")
    good_counts = msg.counts[MESSAGES.GOOD]
    warn_counts = msg.counts[MESSAGES.WARN]

@@ -732,7 +791,13 @@ def _compile_gold(
        "n_cats_multilabel": 0,
        "n_cats_bad_values": 0,
        "texts": set(),
        "lemmatizer_trees": set(),
        "no_lemma_annotations": 0,
        "partial_lemma_annotations": 0,
        "n_low_cardinality_lemmas": 0,
    }
    if "trainable_lemmatizer" in factory_names:
        trees = EditTrees(nlp.vocab.strings)
    for eg in examples:
        gold = eg.reference
        doc = eg.predicted

@@ -862,6 +927,25 @@ def _compile_gold(
                data["n_nonproj"] += 1
            if nonproj.contains_cycle(aligned_heads):
                data["n_cycles"] += 1
        if "trainable_lemmatizer" in factory_names:
            # from EditTreeLemmatizer._labels_from_data
            if all(token.lemma == 0 for token in gold):
                data["no_lemma_annotations"] += 1
                continue
            if any(token.lemma == 0 for token in gold):
                data["partial_lemma_annotations"] += 1
            lemma_set = set()
            for token in gold:
                if token.lemma != 0:
                    lemma_set.add(token.lemma)
                    tree_id = trees.add(token.text, token.lemma_)
                    tree_str = trees.tree_to_str(tree_id)
                    data["lemmatizer_trees"].add(tree_str)
            # We want to identify cases where lemmas aren't assigned
            # or are all assigned the same value, as this would indicate
            # an issue since we're expecting a large set of lemmas
            if len(lemma_set) < 2 and len(gold) > 1:
                data["n_low_cardinality_lemmas"] += 1
    return data
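A short aside on the normalised label entropy added in the tagger block above: it is the Shannon entropy of the tag distribution divided by log2 of the number of labels, so 1.0 means a perfectly uniform distribution and values near 0 mean a single tag dominates. A toy check mirroring the new numpy lines (the counts are illustrative):

    import numpy

    counts = numpy.array([50, 30, 20])  # e.g. NOUN, VERB, ADJ frequencies
    p = counts / counts.sum()
    norm_entropy = (-p * numpy.log2(p)).sum() / numpy.log2(len(counts))
    print(round(norm_entropy, 2))  # 0.94, i.e. fairly balanced tags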
@@ -7,12 +7,15 @@ from thinc.api import fix_random_seed

from ..training import Corpus
from ..tokens import Doc
from ._util import app, Arg, Opt, setup_gpu, import_code
from ._util import app, Arg, Opt, setup_gpu, import_code, benchmark_cli
from ..scorer import Scorer
from .. import util
from .. import displacy


@benchmark_cli.command(
    "accuracy",
)
@app.command("evaluate")
def evaluate_cli(
    # fmt: off

@@ -36,7 +39,7 @@ def evaluate_cli(
    dependency parses in a HTML file, set as output directory as the
    displacy_path argument.

    DOCS: https://spacy.io/api/cli#evaluate
    DOCS: https://spacy.io/api/cli#benchmark-accuracy
    """
    import_code(code_path)
    evaluate(
@@ -35,7 +35,7 @@ def find_threshold_cli(
    code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
    use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
    verbose: bool = Opt(False, "--silent", "-V", "-VV", help="Display more information for debugging purposes"),
    verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
    # fmt: on
):
    """
|
|||
from typing import Optional, Dict, Any, Union, List
|
||||
import platform
|
||||
import pkg_resources
|
||||
import json
|
||||
from pathlib import Path
|
||||
from wasabi import Printer, MarkdownRenderer
|
||||
|
@ -10,6 +9,7 @@ from ._util import app, Arg, Opt, string_to_list
|
|||
from .download import get_model_filename, get_latest_version
|
||||
from .. import util
|
||||
from .. import about
|
||||
from ..compat import importlib_metadata
|
||||
|
||||
|
||||
@app.command("info")
|
||||
|
@ -137,15 +137,14 @@ def info_installed_model_url(model: str) -> Optional[str]:
|
|||
dist-info available.
|
||||
"""
|
||||
try:
|
||||
dist = pkg_resources.get_distribution(model)
|
||||
data = json.loads(dist.get_metadata("direct_url.json"))
|
||||
return data["url"]
|
||||
except pkg_resources.DistributionNotFound:
|
||||
# no such package
|
||||
return None
|
||||
dist = importlib_metadata.distribution(model)
|
||||
text = dist.read_text("direct_url.json")
|
||||
if isinstance(text, str):
|
||||
data = json.loads(text)
|
||||
return data["url"]
|
||||
except Exception:
|
||||
# something else, like no file or invalid JSON
|
||||
return None
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def info_model_url(model: str) -> Dict[str, Any]:
|
||||
|
|
|
@ -252,7 +252,7 @@ def get_third_party_dependencies(
|
|||
raise regerr from None
|
||||
module_name = func_info.get("module") # type: ignore[attr-defined]
|
||||
if module_name: # the code is part of a module, not a --code file
|
||||
modules.add(func_info["module"].split(".")[0]) # type: ignore[index]
|
||||
modules.add(func_info["module"].split(".")[0]) # type: ignore[union-attr]
|
||||
dependencies = []
|
||||
for module_name in modules:
|
||||
if module_name in distributions:
|
||||
|
|
|
@@ -23,6 +23,7 @@ def pretrain_cli(
    resume_path: Optional[Path] = Opt(None, "--resume-path", "-r", help="Path to pretrained weights from which to resume pretraining"),
    epoch_resume: Optional[int] = Opt(None, "--epoch-resume", "-er", help="The epoch to resume counting from when using --resume-path. Prevents unintended overwriting of existing weight files."),
    use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
    skip_last: bool = Opt(False, "--skip-last", "-L", help="Skip saving model-last.bin"),
    # fmt: on
):
    """

@@ -74,6 +75,7 @@ def pretrain_cli(
        epoch_resume=epoch_resume,
        use_gpu=use_gpu,
        silent=False,
        skip_last=skip_last,
    )
    msg.good("Successfully finished pretrain")
@@ -39,14 +39,17 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
    # in the list.
    while commands:
        for i, cmd in enumerate(list(commands)):
            logger.debug(f"CMD: {cmd['name']}.")
            logger.debug("CMD: %s.", cmd["name"])
            deps = [project_dir / dep for dep in cmd.get("deps", [])]
            if all(dep.exists() for dep in deps):
                cmd_hash = get_command_hash("", "", deps, cmd["script"])
                for output_path in cmd.get("outputs", []):
                    url = storage.pull(output_path, command_hash=cmd_hash)
                    logger.debug(
                        f"URL: {url} for {output_path} with command hash {cmd_hash}"
                        "URL: %s for %s with command hash %s",
                        url,
                        output_path,
                        cmd_hash,
                    )
                    yield url, output_path

@@ -58,7 +61,7 @@ def project_pull(project_dir: Path, remote: str, *, verbose: bool = False):
                commands.pop(i)
                break
            else:
                logger.debug(f"Dependency missing. Skipping {cmd['name']} outputs.")
                logger.debug("Dependency missing. Skipping %s outputs.", cmd["name"])
        else:
            # If we didn't break the for loop, break the while loop.
            break
@@ -37,15 +37,15 @@ def project_push(project_dir: Path, remote: str):
    remote = config["remotes"][remote]
    storage = RemoteStorage(project_dir, remote)
    for cmd in config.get("commands", []):
        logger.debug(f"CMD: cmd['name']")
        logger.debug("CMD: %s", cmd["name"])
        deps = [project_dir / dep for dep in cmd.get("deps", [])]
        if any(not dep.exists() for dep in deps):
            logger.debug(f"Dependency missing. Skipping {cmd['name']} outputs")
            logger.debug("Dependency missing. Skipping %s outputs", cmd["name"])
            continue
        cmd_hash = get_command_hash(
            "", "", [project_dir / dep for dep in cmd.get("deps", [])], cmd["script"]
        )
        logger.debug(f"CMD_HASH: {cmd_hash}")
        logger.debug("CMD_HASH: %s", cmd_hash)
        for output_path in cmd.get("outputs", []):
            output_loc = project_dir / output_path
            if output_loc.exists() and _is_not_empty_dir(output_loc):

@@ -55,7 +55,7 @@ def project_push(project_dir: Path, remote: str):
                    content_hash=get_content_hash(output_loc),
                )
                logger.debug(
                    f"URL: {url} for output {output_path} with cmd_hash {cmd_hash}"
                    "URL: %s for output %s with cmd_hash %s", url, output_path, cmd_hash
                )
                yield output_path, url
@@ -2,7 +2,6 @@ from typing import Optional, List, Dict, Sequence, Any, Iterable, Tuple
import os.path
from pathlib import Path

import pkg_resources
from wasabi import msg
from wasabi.util import locale_escape
import sys

@@ -331,6 +330,7 @@ def _check_requirements(requirements: List[str]) -> Tuple[bool, bool]:
    RETURNS (Tuple[bool, bool]): Whether (1) any packages couldn't be imported, (2) any packages with version conflicts
    exist.
    """
    import pkg_resources

    failed_pkgs_msgs: List[str] = []
    conflicting_pkgs_msgs: List[str] = []
@ -3,7 +3,7 @@ the docs and the init config command. It encodes various best practices and
|
|||
can help generate the best possible configuration, given a user's requirements. #}
|
||||
{%- set use_transformer = hardware != "cpu" and transformer_data -%}
|
||||
{%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
|
||||
{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "trainable_lemmatizer"] -%}
|
||||
{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%}
|
||||
[paths]
|
||||
train = null
|
||||
dev = null
|
||||
|
@ -24,8 +24,11 @@ gpu_allocator = null
|
|||
lang = "{{ lang }}"
|
||||
{%- set has_textcat = ("textcat" in components or "textcat_multilabel" in components) -%}
|
||||
{%- set with_accuracy = optimize == "accuracy" -%}
|
||||
{%- set has_accurate_textcat = has_textcat and with_accuracy -%}
|
||||
{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "trainable_lemmatizer" in components or "entity_linker" in components or has_accurate_textcat) -%}
|
||||
{# The BOW textcat doesn't need a source of features, so it can omit the
|
||||
tok2vec/transformer. #}
|
||||
{%- set with_accuracy_or_transformer = (use_transformer or with_accuracy) -%}
|
||||
{%- set textcat_needs_features = has_textcat and with_accuracy_or_transformer -%}
|
||||
{%- if ("tagger" in components or "morphologizer" in components or "parser" in components or "ner" in components or "spancat" in components or "spancat_singlelabel" in components or "trainable_lemmatizer" in components or "entity_linker" in components or textcat_needs_features) -%}
|
||||
{%- set full_pipeline = ["transformer" if use_transformer else "tok2vec"] + components -%}
|
||||
{%- else -%}
|
||||
{%- set full_pipeline = components -%}
|
||||
|
@ -156,6 +159,36 @@ grad_factor = 1.0
|
|||
sizes = [1,2,3]
|
||||
{% endif -%}
|
||||
|
||||
{% if "spancat_singlelabel" in components %}
|
||||
[components.spancat_singlelabel]
|
||||
factory = "spancat_singlelabel"
|
||||
negative_weight = 1.0
|
||||
allow_overlap = true
|
||||
scorer = {"@scorers":"spacy.spancat_scorer.v1"}
|
||||
spans_key = "sc"
|
||||
|
||||
[components.spancat_singlelabel.model]
|
||||
@architectures = "spacy.SpanCategorizer.v1"
|
||||
|
||||
[components.spancat_singlelabel.model.reducer]
|
||||
@layers = "spacy.mean_max_reducer.v1"
|
||||
hidden_size = 128
|
||||
|
||||
[components.spancat_singlelabel.model.scorer]
|
||||
@layers = "Softmax.v2"
|
||||
|
||||
[components.spancat_singlelabel.model.tok2vec]
|
||||
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||
grad_factor = 1.0
|
||||
|
||||
[components.spancat_singlelabel.model.tok2vec.pooling]
|
||||
@layers = "reduce_mean.v1"
|
||||
|
||||
[components.spancat_singlelabel.suggester]
|
||||
@misc = "spacy.ngram_suggester.v1"
|
||||
sizes = [1,2,3]
|
||||
{% endif %}
|
||||
|
||||
{% if "trainable_lemmatizer" in components -%}
|
||||
[components.trainable_lemmatizer]
|
||||
factory = "trainable_lemmatizer"
|
||||
|
@ -221,10 +254,16 @@ no_output_layer = false
|
|||
|
||||
{% else -%}
|
||||
[components.textcat.model]
|
||||
@architectures = "spacy.TextCatBOW.v2"
|
||||
@architectures = "spacy.TextCatCNN.v2"
|
||||
exclusive_classes = true
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
nO = null
|
||||
|
||||
[components.textcat.model.tok2vec]
|
||||
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||
grad_factor = 1.0
|
||||
|
||||
[components.textcat.model.tok2vec.pooling]
|
||||
@layers = "reduce_mean.v1"
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
|
@ -252,10 +291,16 @@ no_output_layer = false
|
|||
|
||||
{% else -%}
|
||||
[components.textcat_multilabel.model]
|
||||
@architectures = "spacy.TextCatBOW.v2"
|
||||
@architectures = "spacy.TextCatCNN.v2"
|
||||
exclusive_classes = false
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
nO = null
|
||||
|
||||
[components.textcat_multilabel.model.tok2vec]
|
||||
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||
grad_factor = 1.0
|
||||
|
||||
[components.textcat_multilabel.model.tok2vec.pooling]
|
||||
@layers = "reduce_mean.v1"
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
|
@ -286,6 +331,7 @@ maxout_pieces = 3
|
|||
{% if "morphologizer" in components %}
|
||||
[components.morphologizer]
|
||||
factory = "morphologizer"
|
||||
label_smoothing = 0.05
|
||||
|
||||
[components.morphologizer.model]
|
||||
@architectures = "spacy.Tagger.v2"
|
||||
|
@ -299,6 +345,7 @@ width = ${components.tok2vec.model.encode.width}
|
|||
{% if "tagger" in components %}
|
||||
[components.tagger]
|
||||
factory = "tagger"
|
||||
label_smoothing = 0.05
|
||||
|
||||
[components.tagger.model]
|
||||
@architectures = "spacy.Tagger.v2"
|
||||
|
@ -374,6 +421,33 @@ width = ${components.tok2vec.model.encode.width}
|
|||
sizes = [1,2,3]
|
||||
{% endif %}
|
||||
|
||||
{% if "spancat_singlelabel" in components %}
|
||||
[components.spancat_singlelabel]
|
||||
factory = "spancat_singlelabel"
|
||||
negative_weight = 1.0
|
||||
allow_overlap = true
|
||||
scorer = {"@scorers":"spacy.spancat_scorer.v1"}
|
||||
spans_key = "sc"
|
||||
|
||||
[components.spancat_singlelabel.model]
|
||||
@architectures = "spacy.SpanCategorizer.v1"
|
||||
|
||||
[components.spancat_singlelabel.model.reducer]
|
||||
@layers = "spacy.mean_max_reducer.v1"
|
||||
hidden_size = 128
|
||||
|
||||
[components.spancat_singlelabel.model.scorer]
|
||||
@layers = "Softmax.v2"
|
||||
|
||||
[components.spancat_singlelabel.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode.width}
|
||||
|
||||
[components.spancat_singlelabel.suggester]
|
||||
@misc = "spacy.ngram_suggester.v1"
|
||||
sizes = [1,2,3]
|
||||
{% endif %}
|
||||
|
||||
{% if "trainable_lemmatizer" in components -%}
|
||||
[components.trainable_lemmatizer]
|
||||
factory = "trainable_lemmatizer"
|
||||
|
|
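Note: the template above wires up a "spancat_singlelabel" component with the "sc" spans key. A minimal usage sketch (illustrative, not part of this patch; the language and label are placeholders):

import spacy

nlp = spacy.blank("en")  # placeholder pipeline
# Registers the single-label span categorizer with its default ngram suggester.
spancat = nlp.add_pipe("spancat_singlelabel")
spancat.add_label("FRUIT")  # hypothetical label
# After training, each suggested span gets at most one label and the
# predictions are stored under doc.spans["sc"] (the spans_key configured above).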
|
@ -11,6 +11,7 @@ from .render import DependencyRenderer, EntityRenderer, SpanRenderer
|
|||
from ..tokens import Doc, Span
|
||||
from ..errors import Errors, Warnings
|
||||
from ..util import is_in_jupyter
|
||||
from ..util import find_available_port
|
||||
|
||||
|
||||
_html = {}
|
||||
|
@ -36,7 +37,7 @@ def render(
|
|||
jupyter (bool): Override Jupyter auto-detection.
|
||||
options (dict): Visualiser-specific options, e.g. colors.
|
||||
manual (bool): Don't parse `Doc` and instead expect a dict/list of dicts.
|
||||
RETURNS (str): Rendered HTML markup.
|
||||
RETURNS (str): Rendered SVG or HTML markup.
|
||||
|
||||
DOCS: https://spacy.io/api/top-level#displacy.render
|
||||
USAGE: https://spacy.io/usage/visualizers
|
||||
|
@ -82,6 +83,7 @@ def serve(
|
|||
manual: bool = False,
|
||||
port: int = 5000,
|
||||
host: str = "0.0.0.0",
|
||||
auto_select_port: bool = False,
|
||||
) -> None:
|
||||
"""Serve displaCy visualisation.
|
||||
|
||||
|
@ -93,12 +95,15 @@ def serve(
|
|||
manual (bool): Don't parse `Doc` and instead expect a dict/list of dicts.
|
||||
port (int): Port to serve visualisation.
|
||||
host (str): Host to serve visualisation.
|
||||
auto_select_port (bool): Automatically select a port if the specified port is in use.
|
||||
|
||||
DOCS: https://spacy.io/api/top-level#displacy.serve
|
||||
USAGE: https://spacy.io/usage/visualizers
|
||||
"""
|
||||
from wsgiref import simple_server
|
||||
|
||||
port = find_available_port(port, host, auto_select_port)
|
||||
|
||||
if is_in_jupyter():
|
||||
warnings.warn(Warnings.W011)
|
||||
render(docs, style=style, page=page, minify=minify, options=options, manual=manual)
|
||||
|
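Note: the new auto_select_port flag (together with W124 and E1049/E1050 further down) lets displacy.serve fall back to the nearest free port instead of failing. A minimal sketch, assuming an installed pipeline such as en_core_web_sm:

import spacy
from spacy import displacy

nlp = spacy.load("en_core_web_sm")  # assumed installed
doc = nlp("She works in Berlin.")
# If port 5000 is taken, a nearby free port is chosen and W124 is emitted.
displacy.serve(doc, style="dep", port=5000, auto_select_port=True)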
@ -120,13 +125,17 @@ def app(environ, start_response):
|
|||
return [res]
|
||||
|
||||
|
||||
def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
|
||||
def parse_deps(
|
||||
orig_doc: Union[Doc, Span], options: Dict[str, Any] = {}
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate dependency parse in {'words': [], 'arcs': []} format.
|
||||
|
||||
orig_doc (Doc): Document to parse.
|
||||
orig_doc (Union[Doc, Span]): Document to parse.
|
||||
options (Dict[str, Any]): Dependency parse specific visualisation options.
|
||||
RETURNS (dict): Generated dependency parse keyed by words and arcs.
|
||||
"""
|
||||
if isinstance(orig_doc, Span):
|
||||
orig_doc = orig_doc.as_doc()
|
||||
doc = Doc(orig_doc.vocab).from_bytes(
|
||||
orig_doc.to_bytes(exclude=["user_data", "user_hooks"])
|
||||
)
|
||||
|
|
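Note: parse_deps now accepts a Span as well as a Doc, converting it via Span.as_doc() as shown above. A small sketch (illustrative, not part of this patch; assumes an installed en_core_web_sm pipeline):

import spacy
from spacy import displacy

nlp = spacy.load("en_core_web_sm")  # assumed installed
doc = nlp("She lives in Berlin. He works in Paris.")
sent = list(doc.sents)[1]            # a Span
parsed = displacy.parse_deps(sent)   # Spans are now accepted directly
print(parsed["words"][:2])           # {'words': [...], 'arcs': [...]} format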
|
@ -94,7 +94,7 @@ class SpanRenderer:
|
|||
parsed (list): Dependency parses to render.
|
||||
page (bool): Render parses wrapped as full HTML page.
|
||||
minify (bool): Minify HTML markup.
|
||||
RETURNS (str): Rendered HTML markup.
|
||||
RETURNS (str): Rendered SVG or HTML markup.
|
||||
"""
|
||||
rendered = []
|
||||
for i, p in enumerate(parsed):
|
||||
|
@ -510,7 +510,7 @@ class EntityRenderer:
|
|||
parsed (list): Dependency parses to render.
|
||||
page (bool): Render parses wrapped as full HTML page.
|
||||
minify (bool): Minify HTML markup.
|
||||
RETURNS (str): Rendered HTML markup.
|
||||
RETURNS (str): Rendered SVG or HTML markup.
|
||||
"""
|
||||
rendered = []
|
||||
for i, p in enumerate(parsed):
|
||||
|
|
|
@ -214,6 +214,7 @@ class Warnings(metaclass=ErrorsWithCodes):
|
|||
"is a Cython extension type.")
|
||||
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
|
||||
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
|
||||
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
|
||||
|
||||
|
||||
class Errors(metaclass=ErrorsWithCodes):
|
||||
|
@ -443,8 +444,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E133 = ("The sum of prior probabilities for alias '{alias}' should not "
|
||||
"exceed 1, but found {sum}.")
|
||||
E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
|
||||
E139 = ("Knowledge base for component '{name}' is empty. Use the methods "
|
||||
"`kb.add_entity` and `kb.add_alias` to add entries.")
|
||||
E139 = ("Knowledge base for component '{name}' is empty.")
|
||||
E140 = ("The list of entities, prior probabilities and entity vectors "
|
||||
"should be of equal length.")
|
||||
E141 = ("Entity vectors should be of length {required} instead of the "
|
||||
|
@ -549,6 +549,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"during training, make sure to include it in 'annotating components'")
|
||||
|
||||
# New errors added in v3.x
|
||||
E850 = ("The PretrainVectors objective currently only supports default or "
|
||||
"floret vectors, not {mode} vectors.")
|
||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||
"but found value of '{val}'.")
|
||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
||||
|
@ -961,6 +963,12 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E1045 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
|
||||
"knowledge base, use `InMemoryLookupKB`.")
|
||||
E1047 = ("`find_threshold()` only supports components with a `scorer` attribute.")
|
||||
E1048 = ("Got '{unexpected}' as console progress bar type, but expected one of the following: {expected}")
|
||||
E1049 = ("No available port found for displaCy on host {host}. Please specify an available port "
|
||||
"with `displacy.serve(doc, port=port)`")
|
||||
E1050 = ("Port {port} is already in use. Please specify an available port with `displacy.serve(doc, port=port)` "
|
||||
"or use `auto_select_port=True` to pick an available port automatically.")
|
||||
E1051 = ("'allow_overlap' can only be False when max_positive is 1, but found 'max_positive': {max_positive}.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -25,7 +25,7 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
"""An `InMemoryLookupKB` instance stores unique identifiers for entities and their textual aliases,
|
||||
to support entity linking of named entities to real-world concepts.
|
||||
|
||||
DOCS: https://spacy.io/api/kb_in_memory
|
||||
DOCS: https://spacy.io/api/inmemorylookupkb
|
||||
"""
|
||||
|
||||
def __init__(self, Vocab vocab, entity_vector_length):
|
||||
|
@ -46,6 +46,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
|
|||
self._alias_index = PreshMap(nr_aliases + 1)
|
||||
self._aliases_table = alias_vec(nr_aliases + 1)
|
||||
|
||||
def is_empty(self):
|
||||
return len(self) == 0
|
||||
|
||||
@classmethod
|
||||
def generate_from_disk(
|
||||
cls, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
STOP_WORDS = set(
|
||||
"""
|
||||
aan af al alle alles allebei alleen allen als altijd ander anders andere anderen aangaangde aangezien achter achterna
|
||||
aan af al alle alles allebei alleen allen als altijd ander anders andere anderen aangaande aangezien achter achterna
|
||||
afgelopen aldus alhoewel anderzijds
|
||||
|
||||
ben bij bijna bijvoorbeeld behalve beide beiden beneden bent bepaald beter betere betreffende binnen binnenin boven
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from .stop_words import STOP_WORDS
|
||||
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
|
||||
from .lex_attrs import LEX_ATTRS
|
||||
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
|
||||
from ...language import Language, BaseDefaults
|
||||
|
||||
|
||||
class SerbianDefaults(BaseDefaults):
|
||||
tokenizer_exceptions = TOKENIZER_EXCEPTIONS
|
||||
infixes = TOKENIZER_INFIXES
|
||||
suffixes = TOKENIZER_SUFFIXES
|
||||
lex_attr_getters = LEX_ATTRS
|
||||
stop_words = STOP_WORDS
|
||||
|
||||
|
|
spacy/lang/sr/punctuation.py (new file, 36 lines)

|
@ -0,0 +1,36 @@
|
|||
from ..char_classes import LIST_ELLIPSES, LIST_ICONS, LIST_PUNCT, LIST_QUOTES
|
||||
from ..char_classes import CURRENCY, UNITS, PUNCT
|
||||
from ..char_classes import CONCAT_QUOTES, ALPHA, ALPHA_LOWER, ALPHA_UPPER
|
||||
|
||||
|
||||
_infixes = (
|
||||
LIST_ELLIPSES
|
||||
+ LIST_ICONS
|
||||
+ [
|
||||
r"(?<=[0-9])[+\-\*^](?=[0-9-])",
|
||||
r"(?<=[{al}{q}])\.(?=[{au}{q}])".format(
|
||||
al=ALPHA_LOWER, au=ALPHA_UPPER, q=CONCAT_QUOTES
|
||||
),
|
||||
r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
|
||||
r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=ALPHA),
|
||||
]
|
||||
)
|
||||
|
||||
_suffixes = (
|
||||
LIST_PUNCT
|
||||
+ LIST_ELLIPSES
|
||||
+ LIST_QUOTES
|
||||
+ LIST_ICONS
|
||||
+ [
|
||||
r"(?<=[0-9])\+",
|
||||
r"(?<=°[FfCcKk])\.",
|
||||
r"(?<=[0-9])(?:{c})".format(c=CURRENCY),
|
||||
r"(?<=[0-9])(?:{u})".format(u=UNITS),
|
||||
r"(?<=[{a}{e}{p}(?:{q})])\.".format(
|
||||
a=ALPHA, e=r"%²\-\+", q=CONCAT_QUOTES, p=PUNCT
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
TOKENIZER_INFIXES = _infixes
|
||||
TOKENIZER_SUFFIXES = _suffixes
|
|
@ -6,10 +6,7 @@ from .lex_attrs import LEX_ATTRS
|
|||
from .syntax_iterators import SYNTAX_ITERATORS
|
||||
from ...language import Language, BaseDefaults
|
||||
from ...pipeline import Lemmatizer
|
||||
|
||||
|
||||
# Punctuation stolen from Danish
|
||||
from ..da.punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
|
||||
from .punctuation import TOKENIZER_INFIXES, TOKENIZER_SUFFIXES
|
||||
|
||||
|
||||
class SwedishDefaults(BaseDefaults):
|
||||
|
|
spacy/lang/sv/punctuation.py (new file, 33 lines)
|
@ -0,0 +1,33 @@
|
|||
from ..char_classes import LIST_ELLIPSES, LIST_ICONS
|
||||
from ..char_classes import CONCAT_QUOTES, ALPHA, ALPHA_LOWER, ALPHA_UPPER
|
||||
from ..punctuation import TOKENIZER_SUFFIXES
|
||||
|
||||
|
||||
_quotes = CONCAT_QUOTES.replace("'", "")
|
||||
|
||||
_infixes = (
|
||||
LIST_ELLIPSES
|
||||
+ LIST_ICONS
|
||||
+ [
|
||||
r"(?<=[{al}])\.(?=[{au}])".format(al=ALPHA_LOWER, au=ALPHA_UPPER),
|
||||
r"(?<=[{a}])[,!?](?=[{a}])".format(a=ALPHA),
|
||||
r"(?<=[{a}])[<>=](?=[{a}])".format(a=ALPHA),
|
||||
r"(?<=[{a}]):(?=[{a}])".format(a=ALPHA_UPPER),
|
||||
r"(?<=[{a}]),(?=[{a}])".format(a=ALPHA),
|
||||
r"(?<=[{a}])([{q}\)\]\(\[])(?=[{a}])".format(a=ALPHA, q=_quotes),
|
||||
r"(?<=[{a}])--(?=[{a}])".format(a=ALPHA),
|
||||
r"(?<=[{a}0-9])[<>=/](?=[{a}])".format(a=ALPHA),
|
||||
r"(?<=[{a}0-9]):(?=[{a}])".format(a=ALPHA_UPPER),
|
||||
]
|
||||
)
|
||||
|
||||
_suffixes = [
|
||||
suffix
|
||||
for suffix in TOKENIZER_SUFFIXES
|
||||
if suffix not in ["'s", "'S", "’s", "’S", r"\'"]
|
||||
]
|
||||
_suffixes += [r"(?<=[^sSxXzZ])\'"]
|
||||
|
||||
|
||||
TOKENIZER_INFIXES = _infixes
|
||||
TOKENIZER_SUFFIXES = _suffixes
|
|
@ -104,7 +104,7 @@ def create_tokenizer() -> Callable[["Language"], Tokenizer]:
|
|||
|
||||
@registry.misc("spacy.LookupsDataLoader.v1")
|
||||
def load_lookups_data(lang, tables):
|
||||
util.logger.debug(f"Loading lookups from spacy-lookups-data: {tables}")
|
||||
util.logger.debug("Loading lookups from spacy-lookups-data: %s", tables)
|
||||
lookups = load_lookups(lang=lang, tables=tables)
|
||||
return lookups
|
||||
|
||||
|
@ -1969,7 +1969,7 @@ class Language:
|
|||
pipe = self.get_pipe(pipe_name)
|
||||
pipe_cfg = self._pipe_configs[pipe_name]
|
||||
if listeners:
|
||||
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
|
||||
util.logger.debug("Replacing listeners of component '%s'", pipe_name)
|
||||
if len(list(listeners)) != len(pipe_listeners):
|
||||
# The number of listeners defined in the component model doesn't
|
||||
# match the listeners to replace, so we won't be able to update
|
||||
|
|
|
@ -25,7 +25,8 @@ class Lexeme:
|
|||
def orth_(self) -> str: ...
|
||||
@property
|
||||
def text(self) -> str: ...
|
||||
lower: str
|
||||
orth: int
|
||||
lower: int
|
||||
norm: int
|
||||
shape: int
|
||||
prefix: int
|
||||
|
|
|
@ -199,7 +199,7 @@ cdef class Lexeme:
|
|||
return self.orth_
|
||||
|
||||
property lower:
|
||||
"""RETURNS (str): Lowercase form of the lexeme."""
|
||||
"""RETURNS (uint64): Lowercase form of the lexeme."""
|
||||
def __get__(self):
|
||||
return self.c.lower
|
||||
|
||||
|
|
|
@ -82,8 +82,12 @@ cdef class DependencyMatcher:
|
|||
"$-": self._imm_left_sib,
|
||||
"$++": self._right_sib,
|
||||
"$--": self._left_sib,
|
||||
">+": self._imm_right_child,
|
||||
">-": self._imm_left_child,
|
||||
">++": self._right_child,
|
||||
">--": self._left_child,
|
||||
"<+": self._imm_right_parent,
|
||||
"<-": self._imm_left_parent,
|
||||
"<++": self._right_parent,
|
||||
"<--": self._left_parent,
|
||||
}
|
||||
|
@ -427,11 +431,33 @@ cdef class DependencyMatcher:
|
|||
def _left_sib(self, doc, node):
|
||||
return [doc[child.i] for child in doc[node].head.children if child.i < node]
|
||||
|
||||
def _imm_right_child(self, doc, node):
|
||||
for child in doc[node].rights:
|
||||
if child.i == node + 1:
|
||||
return [doc[child.i]]
|
||||
return []
|
||||
|
||||
def _imm_left_child(self, doc, node):
|
||||
for child in doc[node].lefts:
|
||||
if child.i == node - 1:
|
||||
return [doc[child.i]]
|
||||
return []
|
||||
|
||||
def _right_child(self, doc, node):
|
||||
return [doc[child.i] for child in doc[node].children if child.i > node]
|
||||
return [child for child in doc[node].rights]
|
||||
|
||||
def _left_child(self, doc, node):
|
||||
return [doc[child.i] for child in doc[node].children if child.i < node]
|
||||
return [child for child in doc[node].lefts]
|
||||
|
||||
def _imm_right_parent(self, doc, node):
|
||||
if doc[node].head.i == node + 1:
|
||||
return [doc[node].head]
|
||||
return []
|
||||
|
||||
def _imm_left_parent(self, doc, node):
|
||||
if doc[node].head.i == node - 1:
|
||||
return [doc[node].head]
|
||||
return []
|
||||
|
||||
def _right_parent(self, doc, node):
|
||||
if doc[node].head.i > node:
|
||||
|
|
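Note: the DependencyMatcher hunk above adds the immediate-child operators ">+" and ">-" (keeping ">++"/">--" for any right/left child). A small sketch of how such an operator might be used (illustrative, not part of this patch; assumes an installed en_core_web_sm pipeline):

import spacy
from spacy.matcher import DependencyMatcher

nlp = spacy.load("en_core_web_sm")  # assumed installed
matcher = DependencyMatcher(nlp.vocab)
pattern = [
    {"RIGHT_ID": "verb", "RIGHT_ATTRS": {"POS": "VERB"}},
    # ">-": an immediate left child of the verb (the token directly before it).
    {"LEFT_ID": "verb", "REL_OP": ">-", "RIGHT_ID": "subj",
     "RIGHT_ATTRS": {"DEP": "nsubj"}},
]
matcher.add("SUBJ_VERB", [pattern])
doc = nlp("She sleeps.")
print(matcher(doc))  # list of (match_id, [token indices]) tuples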
|
@ -4,6 +4,8 @@ from libc.stdint cimport int64_t
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from ..util import registry
|
||||
|
||||
|
||||
cdef extern from "polyleven.c":
|
||||
int64_t polyleven(PyObject *o1, PyObject *o2, int64_t k)
|
||||
|
@ -13,3 +15,18 @@ cpdef int64_t levenshtein(a: str, b: str, k: Optional[int] = None):
|
|||
if k is None:
|
||||
k = -1
|
||||
return polyleven(<PyObject*>a, <PyObject*>b, k)
|
||||
|
||||
|
||||
cpdef bint levenshtein_compare(input_text: str, pattern_text: str, fuzzy: int = -1):
|
||||
if fuzzy >= 0:
|
||||
max_edits = fuzzy
|
||||
else:
|
||||
# allow at least two edits (to allow at least one transposition) and up
|
||||
# to 30% of the pattern string length
|
||||
max_edits = max(2, round(0.3 * len(pattern_text)))
|
||||
return levenshtein(input_text, pattern_text, max_edits) <= max_edits
|
||||
|
||||
|
||||
@registry.misc("spacy.levenshtein_compare.v1")
|
||||
def make_levenshtein_compare():
|
||||
return levenshtein_compare
|
||||
|
|
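Note: when no explicit fuzziness is given, levenshtein_compare caps the edit distance at max(2, round(0.3 * len(pattern_text))). A worked illustration of that default budget in plain Python, mirroring the logic above rather than calling the Cython function:

def default_edit_budget(pattern_text: str) -> int:
    # At least two edits (so a single transposition always fits),
    # otherwise roughly 30% of the pattern length.
    return max(2, round(0.3 * len(pattern_text)))

print(default_edit_budget("cat"))             # 2
print(default_edit_budget("definitely"))      # 3  (round(3.0))
print(default_edit_budget("north carolina"))  # 4  (round(4.2))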
|
@ -77,3 +77,4 @@ cdef class Matcher:
|
|||
cdef public object _extensions
|
||||
cdef public object _extra_predicates
|
||||
cdef public object _seen_attrs
|
||||
cdef public object _fuzzy_compare
|
||||
|
|
|
@ -5,7 +5,12 @@ from ..vocab import Vocab
|
|||
from ..tokens import Doc, Span
|
||||
|
||||
class Matcher:
|
||||
def __init__(self, vocab: Vocab, validate: bool = ...) -> None: ...
|
||||
def __init__(
|
||||
self,
|
||||
vocab: Vocab,
|
||||
validate: bool = ...,
|
||||
fuzzy_compare: Callable[[str, str, int], bool] = ...,
|
||||
) -> None: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __contains__(self, key: str) -> bool: ...
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
# cython: binding=True, infer_types=True, profile=True
|
||||
from typing import List, Iterable
|
||||
|
||||
from libcpp.vector cimport vector
|
||||
|
@ -20,10 +20,12 @@ from ..tokens.token cimport Token
|
|||
from ..tokens.morphanalysis cimport MorphAnalysis
|
||||
from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA, MORPH, ENT_IOB
|
||||
|
||||
from .levenshtein import levenshtein_compare
|
||||
from ..schemas import validate_token_pattern
|
||||
from ..errors import Errors, MatchPatternError, Warnings
|
||||
from ..strings import get_string_id
|
||||
from ..attrs import IDS
|
||||
from ..util import registry
|
||||
|
||||
|
||||
DEF PADDING = 5
|
||||
|
@ -36,11 +38,13 @@ cdef class Matcher:
|
|||
USAGE: https://spacy.io/usage/rule-based-matching
|
||||
"""
|
||||
|
||||
def __init__(self, vocab, validate=True):
|
||||
def __init__(self, vocab, validate=True, *, fuzzy_compare=levenshtein_compare):
|
||||
"""Create the Matcher.
|
||||
|
||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||
documents the matcher will operate on.
|
||||
validate (bool): Validate all patterns added to this matcher.
|
||||
fuzzy_compare (Callable[[str, str, int], bool]): The comparison method
|
||||
for the FUZZY operators.
|
||||
"""
|
||||
self._extra_predicates = []
|
||||
self._patterns = {}
|
||||
|
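Note: with fuzzy_compare wired into the Matcher, token patterns can use the FUZZY predicates handled by _FuzzyPredicate below, including nesting inside IN. A minimal sketch of the pattern syntax introduced by this diff (illustrative, not part of the patch):

import spacy
from spacy.matcher import Matcher

nlp = spacy.blank("en")
matcher = Matcher(nlp.vocab)  # defaults to levenshtein_compare
patterns = [
    # Default edit budget: max(2, round(0.3 * len("definitely"))) == 3 edits.
    [{"LOWER": {"FUZZY": "definitely"}}],
    # Explicit budget of 1 edit, combined with a set membership test.
    [{"LOWER": {"FUZZY1": {"IN": ["awesome", "cool"]}}}],
]
matcher.add("FUZZY_EXAMPLES", patterns)
doc = nlp("This is definitly awesme.")
print([doc[s:e].text for _, s, e in matcher(doc)])  # both misspellings match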
@ -51,9 +55,10 @@ cdef class Matcher:
|
|||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
self.validate = validate
|
||||
self._fuzzy_compare = fuzzy_compare
|
||||
|
||||
def __reduce__(self):
|
||||
data = (self.vocab, self._patterns, self._callbacks)
|
||||
data = (self.vocab, self._patterns, self._callbacks, self.validate, self._fuzzy_compare)
|
||||
return (unpickle_matcher, data, None, None)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -128,7 +133,7 @@ cdef class Matcher:
|
|||
for pattern in patterns:
|
||||
try:
|
||||
specs = _preprocess_pattern(pattern, self.vocab,
|
||||
self._extensions, self._extra_predicates)
|
||||
self._extensions, self._extra_predicates, self._fuzzy_compare)
|
||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||
for spec in specs:
|
||||
for attr, _ in spec[1]:
|
||||
|
@ -326,8 +331,8 @@ cdef class Matcher:
|
|||
return key
|
||||
|
||||
|
||||
def unpickle_matcher(vocab, patterns, callbacks):
|
||||
matcher = Matcher(vocab)
|
||||
def unpickle_matcher(vocab, patterns, callbacks, validate, fuzzy_compare):
|
||||
matcher = Matcher(vocab, validate=validate, fuzzy_compare=fuzzy_compare)
|
||||
for key, pattern in patterns.items():
|
||||
callback = callbacks.get(key, None)
|
||||
matcher.add(key, pattern, on_match=callback)
|
||||
|
@ -754,7 +759,7 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
|||
return id_attr.value
|
||||
|
||||
|
||||
def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates):
|
||||
def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates, fuzzy_compare):
|
||||
"""This function interprets the pattern, converting the various bits of
|
||||
syntactic sugar before we compile it into a struct with init_pattern.
|
||||
|
||||
|
@ -781,7 +786,7 @@ def _preprocess_pattern(token_specs, vocab, extensions_table, extra_predicates):
|
|||
ops = _get_operators(spec)
|
||||
attr_values = _get_attr_values(spec, string_store)
|
||||
extensions = _get_extensions(spec, string_store, extensions_table)
|
||||
predicates = _get_extra_predicates(spec, extra_predicates, vocab)
|
||||
predicates = _get_extra_predicates(spec, extra_predicates, vocab, fuzzy_compare)
|
||||
for op in ops:
|
||||
tokens.append((op, list(attr_values), list(extensions), list(predicates), token_idx))
|
||||
return tokens
|
||||
|
@ -823,19 +828,53 @@ def _get_attr_values(spec, string_store):
|
|||
return attr_values
|
||||
|
||||
|
||||
def _predicate_cache_key(attr, predicate, value, *, regex=False, fuzzy=None):
|
||||
# tuple order affects performance
|
||||
return (attr, regex, fuzzy, predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
|
||||
|
||||
# These predicate helper classes are used to match the REGEX, IN, >= etc
|
||||
# extensions to the matcher introduced in #3173.
|
||||
|
||||
class _FuzzyPredicate:
|
||||
operators = ("FUZZY", "FUZZY1", "FUZZY2", "FUZZY3", "FUZZY4", "FUZZY5",
|
||||
"FUZZY6", "FUZZY7", "FUZZY8", "FUZZY9")
|
||||
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
|
||||
regex=False, fuzzy=None, fuzzy_compare=None):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.value = value
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
if self.predicate not in self.operators:
|
||||
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
||||
fuzz = self.predicate[len("FUZZY"):] # number after prefix
|
||||
self.fuzzy = int(fuzz) if fuzz else -1
|
||||
self.fuzzy_compare = fuzzy_compare
|
||||
self.key = _predicate_cache_key(self.attr, self.predicate, value, fuzzy=self.fuzzy)
|
||||
|
||||
def __call__(self, Token token):
|
||||
if self.is_extension:
|
||||
value = token._.get(self.attr)
|
||||
else:
|
||||
value = token.vocab.strings[get_token_attr_for_matcher(token.c, self.attr)]
|
||||
if self.value == value:
|
||||
return True
|
||||
return self.fuzzy_compare(value, self.value, self.fuzzy)
|
||||
|
||||
|
||||
class _RegexPredicate:
|
||||
operators = ("REGEX",)
|
||||
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
|
||||
regex=False, fuzzy=None, fuzzy_compare=None):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.value = re.compile(value)
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
self.key = _predicate_cache_key(self.attr, self.predicate, value)
|
||||
if self.predicate not in self.operators:
|
||||
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
||||
|
||||
|
@ -850,18 +889,28 @@ class _RegexPredicate:
|
|||
class _SetPredicate:
|
||||
operators = ("IN", "NOT_IN", "IS_SUBSET", "IS_SUPERSET", "INTERSECTS")
|
||||
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
|
||||
regex=False, fuzzy=None, fuzzy_compare=None):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.vocab = vocab
|
||||
self.regex = regex
|
||||
self.fuzzy = fuzzy
|
||||
self.fuzzy_compare = fuzzy_compare
|
||||
if self.attr == MORPH:
|
||||
# normalize morph strings
|
||||
self.value = set(self.vocab.morphology.add(v) for v in value)
|
||||
else:
|
||||
self.value = set(get_string_id(v) for v in value)
|
||||
if self.regex:
|
||||
self.value = set(re.compile(v) for v in value)
|
||||
elif self.fuzzy is not None:
|
||||
# add to string store
|
||||
self.value = set(self.vocab.strings.add(v) for v in value)
|
||||
else:
|
||||
self.value = set(get_string_id(v) for v in value)
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
self.key = _predicate_cache_key(self.attr, self.predicate, value, regex=self.regex, fuzzy=self.fuzzy)
|
||||
if self.predicate not in self.operators:
|
||||
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
||||
|
||||
|
@ -889,9 +938,29 @@ class _SetPredicate:
|
|||
return False
|
||||
|
||||
if self.predicate == "IN":
|
||||
return value in self.value
|
||||
if self.regex:
|
||||
value = self.vocab.strings[value]
|
||||
return any(bool(v.search(value)) for v in self.value)
|
||||
elif self.fuzzy is not None:
|
||||
value = self.vocab.strings[value]
|
||||
return any(self.fuzzy_compare(value, self.vocab.strings[v], self.fuzzy)
|
||||
for v in self.value)
|
||||
elif value in self.value:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
elif self.predicate == "NOT_IN":
|
||||
return value not in self.value
|
||||
if self.regex:
|
||||
value = self.vocab.strings[value]
|
||||
return not any(bool(v.search(value)) for v in self.value)
|
||||
elif self.fuzzy is not None:
|
||||
value = self.vocab.strings[value]
|
||||
return not any(self.fuzzy_compare(value, self.vocab.strings[v], self.fuzzy)
|
||||
for v in self.value)
|
||||
elif value in self.value:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
elif self.predicate == "IS_SUBSET":
|
||||
return value <= self.value
|
||||
elif self.predicate == "IS_SUPERSET":
|
||||
|
@ -906,13 +975,14 @@ class _SetPredicate:
|
|||
class _ComparisonPredicate:
|
||||
operators = ("==", "!=", ">=", "<=", ">", "<")
|
||||
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
|
||||
regex=False, fuzzy=None, fuzzy_compare=None):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.value = value
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||
self.key = _predicate_cache_key(self.attr, self.predicate, value)
|
||||
if self.predicate not in self.operators:
|
||||
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
|
||||
|
||||
|
@ -935,7 +1005,7 @@ class _ComparisonPredicate:
|
|||
return value < self.value
|
||||
|
||||
|
||||
def _get_extra_predicates(spec, extra_predicates, vocab):
|
||||
def _get_extra_predicates(spec, extra_predicates, vocab, fuzzy_compare):
|
||||
predicate_types = {
|
||||
"REGEX": _RegexPredicate,
|
||||
"IN": _SetPredicate,
|
||||
|
@ -949,6 +1019,16 @@ def _get_extra_predicates(spec, extra_predicates, vocab):
|
|||
"<=": _ComparisonPredicate,
|
||||
">": _ComparisonPredicate,
|
||||
"<": _ComparisonPredicate,
|
||||
"FUZZY": _FuzzyPredicate,
|
||||
"FUZZY1": _FuzzyPredicate,
|
||||
"FUZZY2": _FuzzyPredicate,
|
||||
"FUZZY3": _FuzzyPredicate,
|
||||
"FUZZY4": _FuzzyPredicate,
|
||||
"FUZZY5": _FuzzyPredicate,
|
||||
"FUZZY6": _FuzzyPredicate,
|
||||
"FUZZY7": _FuzzyPredicate,
|
||||
"FUZZY8": _FuzzyPredicate,
|
||||
"FUZZY9": _FuzzyPredicate,
|
||||
}
|
||||
seen_predicates = {pred.key: pred.i for pred in extra_predicates}
|
||||
output = []
|
||||
|
@ -966,22 +1046,47 @@ def _get_extra_predicates(spec, extra_predicates, vocab):
|
|||
attr = "ORTH"
|
||||
attr = IDS.get(attr.upper())
|
||||
if isinstance(value, dict):
|
||||
processed = False
|
||||
value_with_upper_keys = {k.upper(): v for k, v in value.items()}
|
||||
for type_, cls in predicate_types.items():
|
||||
if type_ in value_with_upper_keys:
|
||||
predicate = cls(len(extra_predicates), attr, value_with_upper_keys[type_], type_, vocab=vocab)
|
||||
# Don't create a redundant predicates.
|
||||
# This helps with efficiency, as we're caching the results.
|
||||
if predicate.key in seen_predicates:
|
||||
output.append(seen_predicates[predicate.key])
|
||||
else:
|
||||
extra_predicates.append(predicate)
|
||||
output.append(predicate.i)
|
||||
seen_predicates[predicate.key] = predicate.i
|
||||
processed = True
|
||||
if not processed:
|
||||
warnings.warn(Warnings.W035.format(pattern=value))
|
||||
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
|
||||
extra_predicates, seen_predicates, fuzzy_compare=fuzzy_compare))
|
||||
return output
|
||||
|
||||
|
||||
def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
|
||||
extra_predicates, seen_predicates, regex=False, fuzzy=None, fuzzy_compare=None):
|
||||
output = []
|
||||
for type_, value in value_dict.items():
|
||||
type_ = type_.upper()
|
||||
cls = predicate_types.get(type_)
|
||||
if cls is None:
|
||||
warnings.warn(Warnings.W035.format(pattern=value_dict))
|
||||
# ignore unrecognized predicate type
|
||||
continue
|
||||
elif cls == _RegexPredicate:
|
||||
if isinstance(value, dict):
|
||||
# add predicates inside regex operator
|
||||
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
|
||||
extra_predicates, seen_predicates,
|
||||
regex=True))
|
||||
continue
|
||||
elif cls == _FuzzyPredicate:
|
||||
if isinstance(value, dict):
|
||||
# add predicates inside fuzzy operator
|
||||
fuzz = type_[len("FUZZY"):] # number after prefix
|
||||
fuzzy_val = int(fuzz) if fuzz else -1
|
||||
output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
|
||||
extra_predicates, seen_predicates,
|
||||
fuzzy=fuzzy_val, fuzzy_compare=fuzzy_compare))
|
||||
continue
|
||||
predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab,
|
||||
regex=regex, fuzzy=fuzzy, fuzzy_compare=fuzzy_compare)
|
||||
# Don't create redundant predicates.
|
||||
# This helps with efficiency, as we're caching the results.
|
||||
if predicate.key in seen_predicates:
|
||||
output.append(seen_predicates[predicate.key])
|
||||
else:
|
||||
extra_predicates.append(predicate)
|
||||
output.append(predicate.i)
|
||||
seen_predicates[predicate.key] = predicate.i
|
||||
return output
|
||||
|
||||
|
||||
|
@ -992,7 +1097,7 @@ def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
|
|||
if isinstance(value, dict):
|
||||
for type_, cls in predicate_types.items():
|
||||
if type_ in value:
|
||||
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
|
||||
key = _predicate_cache_key(attr, type_, value[type_])
|
||||
if key in seen_predicates:
|
||||
output.append(seen_predicates[key])
|
||||
else:
|
||||
|
|
|
@ -89,6 +89,14 @@ def load_kb(
|
|||
return kb_from_file
|
||||
|
||||
|
||||
@registry.misc("spacy.EmptyKB.v2")
|
||||
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
|
||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||
return InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||
|
||||
return empty_kb_factory
|
||||
|
||||
|
||||
@registry.misc("spacy.EmptyKB.v1")
|
||||
def empty_kb(
|
||||
entity_vector_length: int,
|
||||
|
|
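Note: spacy.EmptyKB.v2 returns a factory that receives the vocab and vector length when the KB is created, whereas spacy.EmptyKB.v1 bakes the length in up front. A small sketch of the difference (illustrative, not part of this patch; the length 64 is an arbitrary example):

import spacy
from spacy.util import registry

vocab = spacy.blank("en").vocab

# v2: the vector length is supplied when the KB is actually constructed.
make_kb_v2 = registry.misc.get("spacy.EmptyKB.v2")()
kb = make_kb_v2(vocab, 64)

# v1: the vector length is fixed when the factory is resolved.
make_kb_v1 = registry.misc.get("spacy.EmptyKB.v1")(entity_vector_length=64)
kb = make_kb_v1(vocab)

print(kb.entity_vector_length)  # 64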
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING, cast
|
||||
from thinc.types import Floats2d
|
||||
from thinc.types import Floats2d, Ints1d
|
||||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
||||
from thinc.api import MultiSoftmax, list2array
|
||||
from thinc.api import to_categorical, CosineDistance, L2Distance
|
||||
|
@ -7,7 +7,8 @@ from thinc.loss import Loss
|
|||
|
||||
from ...util import registry, OOV_RANK
|
||||
from ...errors import Errors
|
||||
from ...attrs import ID
|
||||
from ...attrs import ID, ORTH
|
||||
from ...vectors import Mode as VectorsMode
|
||||
|
||||
import numpy
|
||||
from functools import partial
|
||||
|
@ -67,14 +68,23 @@ def get_vectors_loss(ops, docs, prediction, distance):
|
|||
"""Compute a loss based on a distance between the documents' vectors and
|
||||
the prediction.
|
||||
"""
|
||||
# The simplest way to implement this would be to vstack the
|
||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||
# Instead we fetch the index into the vectors table for each of our tokens,
|
||||
# and look them up all at once. This prevents data copying.
|
||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||
target = docs[0].vocab.vectors.data[ids]
|
||||
target[ids == OOV_RANK] = 0
|
||||
d_target, loss = distance(prediction, target)
|
||||
vocab = docs[0].vocab
|
||||
if vocab.vectors.mode == VectorsMode.default:
|
||||
# The simplest way to implement this would be to vstack the
|
||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||
# Instead we fetch the index into the vectors table for each of our
|
||||
# tokens, and look them up all at once. This prevents data copying.
|
||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||
target = docs[0].vocab.vectors.data[ids]
|
||||
target[ids == OOV_RANK] = 0
|
||||
d_target, loss = distance(prediction, target)
|
||||
elif vocab.vectors.mode == VectorsMode.floret:
|
||||
keys = ops.flatten([cast(Ints1d, doc.to_array(ORTH)) for doc in docs])
|
||||
target = vocab.vectors.get_batch(keys)
|
||||
target = ops.as_contig(target)
|
||||
d_target, loss = distance(prediction, target)
|
||||
else:
|
||||
raise ValueError(Errors.E850.format(mode=vocab.vectors.mode))
|
||||
return loss, d_target
|
||||
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ from itertools import islice
|
|||
import numpy as np
|
||||
|
||||
import srsly
|
||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy
|
||||
from thinc.types import Floats2d, Ints1d, Ints2d
|
||||
from thinc.api import Config, Model, SequenceCategoricalCrossentropy, NumpyOps
|
||||
from thinc.types import Floats2d, Ints2d
|
||||
|
||||
from ._edit_tree_internals.edit_trees import EditTrees
|
||||
from ._edit_tree_internals.schemas import validate_edit_tree
|
||||
|
@ -20,6 +20,10 @@ from ..vocab import Vocab
|
|||
from .. import util
|
||||
|
||||
|
||||
# The cutoff value of *top_k* above which an alternative method is used to process guesses.
|
||||
TOP_K_GUARDRAIL = 20
|
||||
|
||||
|
||||
default_model_config = """
|
||||
[model]
|
||||
@architectures = "spacy.Tagger.v2"
|
||||
|
@ -115,6 +119,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
|
||||
self.cfg: Dict[str, Any] = {"labels": []}
|
||||
self.scorer = scorer
|
||||
self.numpy_ops = NumpyOps()
|
||||
|
||||
def get_loss(
|
||||
self, examples: Iterable[Example], scores: List[Floats2d]
|
||||
|
@ -128,7 +133,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
for (predicted, gold_lemma) in zip(
|
||||
eg.predicted, eg.get_aligned("LEMMA", as_string=True)
|
||||
):
|
||||
if gold_lemma is None:
|
||||
if gold_lemma is None or gold_lemma == "":
|
||||
label = -1
|
||||
else:
|
||||
tree_id = self.trees.add(predicted.text, gold_lemma)
|
||||
|
@ -144,6 +149,18 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
return float(loss), d_scores
|
||||
|
||||
def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
|
||||
if self.top_k == 1:
|
||||
scores2guesses = self._scores2guesses_top_k_equals_1
|
||||
elif self.top_k <= TOP_K_GUARDRAIL:
|
||||
scores2guesses = self._scores2guesses_top_k_greater_1
|
||||
else:
|
||||
scores2guesses = self._scores2guesses_top_k_guardrail
|
||||
# The behaviour of *_scores2guesses_top_k_greater_1()* is efficient for values
|
||||
# of *top_k>1* that are likely to be useful when the edit tree lemmatizer is used
|
||||
# for its principal purpose of lemmatizing tokens. However, the code could also
|
||||
# be used for other purposes, and with very large values of *top_k* the method
|
||||
# becomes inefficient. In such cases, *_scores2guesses_top_k_guardrail()* is used
|
||||
# instead.
|
||||
n_docs = len(list(docs))
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
|
@ -153,20 +170,52 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
return guesses
|
||||
scores = self.model.predict(docs)
|
||||
assert len(scores) == n_docs
|
||||
guesses = self._scores2guesses(docs, scores)
|
||||
guesses = scores2guesses(docs, scores)
|
||||
assert len(guesses) == n_docs
|
||||
return guesses
|
||||
|
||||
def _scores2guesses(self, docs, scores):
|
||||
def _scores2guesses_top_k_equals_1(self, docs, scores):
|
||||
guesses = []
|
||||
for doc, doc_scores in zip(docs, scores):
|
||||
if self.top_k == 1:
|
||||
doc_guesses = doc_scores.argmax(axis=1).reshape(-1, 1)
|
||||
else:
|
||||
doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
|
||||
doc_guesses = doc_scores.argmax(axis=1)
|
||||
doc_guesses = self.numpy_ops.asarray(doc_guesses)
|
||||
|
||||
if not isinstance(doc_guesses, np.ndarray):
|
||||
doc_guesses = doc_guesses.get()
|
||||
doc_compat_guesses = []
|
||||
for i, token in enumerate(doc):
|
||||
tree_id = self.cfg["labels"][doc_guesses[i]]
|
||||
if self.trees.apply(tree_id, token.text) is not None:
|
||||
doc_compat_guesses.append(tree_id)
|
||||
else:
|
||||
doc_compat_guesses.append(-1)
|
||||
guesses.append(np.array(doc_compat_guesses))
|
||||
|
||||
return guesses
|
||||
|
||||
def _scores2guesses_top_k_greater_1(self, docs, scores):
|
||||
guesses = []
|
||||
top_k = min(self.top_k, len(self.labels))
|
||||
for doc, doc_scores in zip(docs, scores):
|
||||
doc_scores = self.numpy_ops.asarray(doc_scores)
|
||||
doc_compat_guesses = []
|
||||
for i, token in enumerate(doc):
|
||||
for _ in range(top_k):
|
||||
candidate = int(doc_scores[i].argmax())
|
||||
candidate_tree_id = self.cfg["labels"][candidate]
|
||||
if self.trees.apply(candidate_tree_id, token.text) is not None:
|
||||
doc_compat_guesses.append(candidate_tree_id)
|
||||
break
|
||||
doc_scores[i, candidate] = np.finfo(np.float32).min
|
||||
else:
|
||||
doc_compat_guesses.append(-1)
|
||||
guesses.append(np.array(doc_compat_guesses))
|
||||
|
||||
return guesses
|
||||
|
||||
def _scores2guesses_top_k_guardrail(self, docs, scores):
|
||||
guesses = []
|
||||
for doc, doc_scores in zip(docs, scores):
|
||||
doc_guesses = np.argsort(doc_scores)[..., : -self.top_k - 1 : -1]
|
||||
doc_guesses = self.numpy_ops.asarray(doc_guesses)
|
||||
|
||||
doc_compat_guesses = []
|
||||
for token, candidates in zip(doc, doc_guesses):
|
||||
|
|
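Note: the three _scores2guesses_* variants above trade off differently: for small top_k it is cheaper to take repeated argmaxes and mask them out than to argsort every row. A tiny standalone illustration of that masking trick (NumPy only, not spaCy code):

import numpy as np

scores = np.array([[0.1, 0.7, 0.2, 0.0]], dtype=np.float32)
top_k = 2
row = scores[0].copy()
picked = []
for _ in range(top_k):
    best = int(row.argmax())
    picked.append(best)
    row[best] = np.finfo(np.float32).min  # mask so the next argmax differs
print(picked)  # [1, 2]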
|
@ -265,7 +265,7 @@ class EntityLinker(TrainablePipe):
|
|||
# Raise an error if the knowledge base is not initialized.
|
||||
if self.kb is None:
|
||||
raise ValueError(Errors.E1018.format(name=self.name))
|
||||
if len(self.kb) == 0:
|
||||
if hasattr(self.kb, "is_empty") and self.kb.is_empty():
|
||||
raise ValueError(Errors.E139.format(name=self.name))
|
||||
|
||||
def initialize(
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..errors import Errors, Warnings
|
|||
from ..util import ensure_path, to_disk, from_disk, SimpleFrozenList, registry
|
||||
from ..tokens import Doc, Span
|
||||
from ..matcher import Matcher, PhraseMatcher
|
||||
from ..matcher.levenshtein import levenshtein_compare
|
||||
from ..scorer import get_ner_prf
|
||||
|
||||
|
||||
|
@ -23,6 +24,7 @@ PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
|
|||
assigns=["doc.ents", "token.ent_type", "token.ent_iob"],
|
||||
default_config={
|
||||
"phrase_matcher_attr": None,
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
"validate": False,
|
||||
"overwrite_ents": False,
|
||||
"ent_id_sep": DEFAULT_ENT_ID_SEP,
|
||||
|
@ -39,6 +41,7 @@ def make_entity_ruler(
|
|||
nlp: Language,
|
||||
name: str,
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite_ents: bool,
|
||||
ent_id_sep: str,
|
||||
|
@ -48,6 +51,7 @@ def make_entity_ruler(
|
|||
nlp,
|
||||
name,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite_ents=overwrite_ents,
|
||||
ent_id_sep=ent_id_sep,
|
||||
|
@ -81,6 +85,7 @@ class EntityRuler(Pipe):
|
|||
name: str = "entity_ruler",
|
||||
*,
|
||||
phrase_matcher_attr: Optional[Union[int, str]] = None,
|
||||
matcher_fuzzy_compare: Callable = levenshtein_compare,
|
||||
validate: bool = False,
|
||||
overwrite_ents: bool = False,
|
||||
ent_id_sep: str = DEFAULT_ENT_ID_SEP,
|
||||
|
@ -99,7 +104,10 @@ class EntityRuler(Pipe):
|
|||
added. Used to disable the current entity ruler while creating
|
||||
phrase patterns with the nlp object.
|
||||
phrase_matcher_attr (int / str): Token attribute to match on, passed
|
||||
to the internal PhraseMatcher as `attr`
|
||||
to the internal PhraseMatcher as `attr`.
|
||||
matcher_fuzzy_compare (Callable): The fuzzy comparison method for the
|
||||
internal Matcher. Defaults to
|
||||
spacy.matcher.levenshtein.levenshtein_compare.
|
||||
validate (bool): Whether patterns should be validated, passed to
|
||||
Matcher and PhraseMatcher as `validate`
|
||||
patterns (iterable): Optional patterns to load in.
|
||||
|
@ -117,7 +125,10 @@ class EntityRuler(Pipe):
|
|||
self.token_patterns = defaultdict(list) # type: ignore
|
||||
self.phrase_patterns = defaultdict(list) # type: ignore
|
||||
self._validate = validate
|
||||
self.matcher = Matcher(nlp.vocab, validate=validate)
|
||||
self.matcher_fuzzy_compare = matcher_fuzzy_compare
|
||||
self.matcher = Matcher(
|
||||
nlp.vocab, validate=validate, fuzzy_compare=self.matcher_fuzzy_compare
|
||||
)
|
||||
self.phrase_matcher_attr = phrase_matcher_attr
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
nlp.vocab, attr=self.phrase_matcher_attr, validate=validate
|
||||
|
@ -337,7 +348,11 @@ class EntityRuler(Pipe):
|
|||
self.token_patterns = defaultdict(list)
|
||||
self.phrase_patterns = defaultdict(list)
|
||||
self._ent_ids = defaultdict(tuple)
|
||||
self.matcher = Matcher(self.nlp.vocab, validate=self._validate)
|
||||
self.matcher = Matcher(
|
||||
self.nlp.vocab,
|
||||
validate=self._validate,
|
||||
fuzzy_compare=self.matcher_fuzzy_compare,
|
||||
)
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
self.nlp.vocab, attr=self.phrase_matcher_attr, validate=self._validate
|
||||
)
|
||||
|
@ -431,7 +446,8 @@ class EntityRuler(Pipe):
|
|||
self.overwrite = cfg.get("overwrite", False)
|
||||
self.phrase_matcher_attr = cfg.get("phrase_matcher_attr", None)
|
||||
self.phrase_matcher = PhraseMatcher(
|
||||
self.nlp.vocab, attr=self.phrase_matcher_attr
|
||||
self.nlp.vocab,
|
||||
attr=self.phrase_matcher_attr,
|
||||
)
|
||||
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
|
||||
else:
|
||||
|
|
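Note: the entity ruler now threads matcher_fuzzy_compare (default "spacy.levenshtein_compare.v1") through to its internal Matcher, so a custom comparison can be swapped in via the registry. A hedged sketch; the registry name "my_fuzzy_compare.v1" and the case-insensitive wrapper are made up for illustration:

import spacy
from spacy.util import registry
from spacy.matcher.levenshtein import levenshtein_compare


@registry.misc("my_fuzzy_compare.v1")  # hypothetical name
def make_case_insensitive_compare():
    def compare(input_text: str, pattern_text: str, fuzzy: int) -> bool:
        # Lowercase both sides before the default Levenshtein comparison.
        return levenshtein_compare(input_text.lower(), pattern_text.lower(), fuzzy)
    return compare


nlp = spacy.blank("en")
ruler = nlp.add_pipe(
    "entity_ruler",
    config={"matcher_fuzzy_compare": {"@misc": "my_fuzzy_compare.v1"}},
)
ruler.add_patterns([{"label": "ORG", "pattern": [{"ORTH": {"FUZZY": "Acme"}}]}])
doc = nlp("Acmee launched a product.")
print([(ent.text, ent.label_) for ent in doc.ents])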
|
@ -52,7 +52,8 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
|
|||
@Language.factory(
|
||||
"morphologizer",
|
||||
assigns=["token.morph", "token.pos"],
|
||||
default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
|
||||
default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False,
|
||||
"scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}, "label_smoothing": 0.0},
|
||||
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
|
||||
)
|
||||
def make_morphologizer(
|
||||
|
@ -61,9 +62,10 @@ def make_morphologizer(
|
|||
name: str,
|
||||
overwrite: bool,
|
||||
extend: bool,
|
||||
label_smoothing: float,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer)
|
||||
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, label_smoothing=label_smoothing, scorer=scorer)
|
||||
|
||||
|
||||
def morphologizer_score(examples, **kwargs):
|
||||
|
@ -94,6 +96,7 @@ class Morphologizer(Tagger):
|
|||
*,
|
||||
overwrite: bool = BACKWARD_OVERWRITE,
|
||||
extend: bool = BACKWARD_EXTEND,
|
||||
label_smoothing: float = 0.0,
|
||||
scorer: Optional[Callable] = morphologizer_score,
|
||||
):
|
||||
"""Initialize a morphologizer.
|
||||
|
@ -121,6 +124,7 @@ class Morphologizer(Tagger):
|
|||
"labels_pos": {},
|
||||
"overwrite": overwrite,
|
||||
"extend": extend,
|
||||
"label_smoothing": label_smoothing,
|
||||
}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
self.scorer = scorer
|
||||
|
@ -270,7 +274,8 @@ class Morphologizer(Tagger):
|
|||
DOCS: https://spacy.io/api/morphologizer#get_loss
|
||||
"""
|
||||
validate_examples(examples, "Morphologizer.get_loss")
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False,
|
||||
label_smoothing=self.cfg["label_smoothing"])
|
||||
truths = []
|
||||
for eg in examples:
|
||||
eg_truths = []
|
||||
|
|
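Note: label_smoothing is now a factory-level setting that feeds straight into SequenceCategoricalCrossentropy in get_loss above. A minimal config sketch (illustrative, not part of this patch; the 0.05 value simply mirrors the template defaults earlier in the diff):

import spacy

nlp = spacy.blank("en")
# Non-zero smoothing softens the one-hot POS/morph targets in the loss.
nlp.add_pipe("morphologizer", config={"label_smoothing": 0.05})
print(nlp.config["components"]["morphologizer"]["label_smoothing"])  # 0.05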
|
@ -13,6 +13,7 @@ from ..util import ensure_path, SimpleFrozenList, registry
|
|||
from ..tokens import Doc, Span
|
||||
from ..scorer import Scorer
|
||||
from ..matcher import Matcher, PhraseMatcher
|
||||
from ..matcher.levenshtein import levenshtein_compare
|
||||
from .. import util
|
||||
|
||||
PatternType = Dict[str, Union[str, List[Dict[str, Any]]]]
|
||||
|
@ -28,6 +29,7 @@ DEFAULT_SPANS_KEY = "ruler"
|
|||
"overwrite_ents": False,
|
||||
"scorer": {"@scorers": "spacy.entity_ruler_scorer.v1"},
|
||||
"ent_id_sep": "__unused__",
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
"ents_f": 1.0,
|
||||
|
@ -40,6 +42,7 @@ def make_entity_ruler(
|
|||
nlp: Language,
|
||||
name: str,
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite_ents: bool,
|
||||
scorer: Optional[Callable],
|
||||
|
@ -57,6 +60,7 @@ def make_entity_ruler(
|
|||
annotate_ents=True,
|
||||
ents_filter=ents_filter,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite=False,
|
||||
scorer=scorer,
|
||||
|
@ -72,6 +76,7 @@ def make_entity_ruler(
|
|||
"annotate_ents": False,
|
||||
"ents_filter": {"@misc": "spacy.first_longest_spans_filter.v1"},
|
||||
"phrase_matcher_attr": None,
|
||||
"matcher_fuzzy_compare": {"@misc": "spacy.levenshtein_compare.v1"},
|
||||
"validate": False,
|
||||
"overwrite": True,
|
||||
"scorer": {
|
||||
|
@ -94,6 +99,7 @@ def make_span_ruler(
|
|||
annotate_ents: bool,
|
||||
ents_filter: Callable[[Iterable[Span], Iterable[Span]], Iterable[Span]],
|
||||
phrase_matcher_attr: Optional[Union[int, str]],
|
||||
matcher_fuzzy_compare: Callable,
|
||||
validate: bool,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
|
@ -106,6 +112,7 @@ def make_span_ruler(
|
|||
annotate_ents=annotate_ents,
|
||||
ents_filter=ents_filter,
|
||||
phrase_matcher_attr=phrase_matcher_attr,
|
||||
matcher_fuzzy_compare=matcher_fuzzy_compare,
|
||||
validate=validate,
|
||||
overwrite=overwrite,
|
||||
scorer=scorer,
|
||||
|
@ -170,7 +177,7 @@ def prioritize_existing_ents_filter(
|
|||
|
||||
|
||||
@registry.misc("spacy.prioritize_existing_ents_filter.v1")
|
||||
def make_preverse_existing_ents_filter():
|
||||
def make_preserve_existing_ents_filter():
|
||||
return prioritize_existing_ents_filter
|
||||
|
||||
|
||||
|
@ -216,6 +223,7 @@ class SpanRuler(Pipe):
|
|||
[Iterable[Span], Iterable[Span]], Iterable[Span]
|
||||
] = util.filter_chain_spans,
|
||||
phrase_matcher_attr: Optional[Union[int, str]] = None,
|
||||
matcher_fuzzy_compare: Callable = levenshtein_compare,
|
||||
validate: bool = False,
|
||||
overwrite: bool = False,
|
||||
scorer: Optional[Callable] = partial(
|
||||
|
@ -246,6 +254,9 @@ class SpanRuler(Pipe):
|
|||
phrase_matcher_attr (Optional[Union[int, str]]): Token attribute to
|
||||
match on, passed to the internal PhraseMatcher as `attr`. Defaults
|
||||
to `None`.
|
||||
matcher_fuzzy_compare (Callable): The fuzzy comparison method for the
|
||||
internal Matcher. Defaults to
|
||||
spacy.matcher.levenshtein.levenshtein_compare.
|
||||
validate (bool): Whether patterns should be validated, passed to
|
||||
Matcher and PhraseMatcher as `validate`.
|
||||
overwrite (bool): Whether to remove any existing spans under this spans
|
||||
|
@ -266,6 +277,7 @@ class SpanRuler(Pipe):
|
|||
self.spans_filter = spans_filter
|
||||
self.ents_filter = ents_filter
|
||||
self.scorer = scorer
|
||||
self.matcher_fuzzy_compare = matcher_fuzzy_compare
|
||||
self._match_label_id_map: Dict[int, Dict[str, str]] = {}
|
||||
self.clear()
|
||||
|
||||
|
@ -451,7 +463,11 @@ class SpanRuler(Pipe):
|
|||
DOCS: https://spacy.io/api/spanruler#clear
|
||||
"""
|
||||
self._patterns: List[PatternType] = []
|
||||
self.matcher: Matcher = Matcher(self.nlp.vocab, validate=self.validate)
|
||||
self.matcher: Matcher = Matcher(
|
||||
self.nlp.vocab,
|
||||
validate=self.validate,
|
||||
fuzzy_compare=self.matcher_fuzzy_compare,
|
||||
)
|
||||
self.phrase_matcher: PhraseMatcher = PhraseMatcher(
|
||||
self.nlp.vocab,
|
||||
attr=self.phrase_matcher_attr,
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any
|
||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||
from thinc.api import Optimizer
|
||||
from thinc.types import Ragged, Ints2d, Floats2d
|
||||
|
@ -43,7 +45,36 @@ maxout_pieces = 3
|
|||
depth = 4
|
||||
"""
|
||||
|
||||
spancat_singlelabel_default_config = """
|
||||
[model]
|
||||
@architectures = "spacy.SpanCategorizer.v1"
|
||||
scorer = {"@layers": "Softmax.v2"}
|
||||
|
||||
[model.reducer]
|
||||
@layers = spacy.mean_max_reducer.v1
|
||||
hidden_size = 128
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.Tok2Vec.v2"
|
||||
[model.tok2vec.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v1"
|
||||
width = 96
|
||||
rows = [5000, 1000, 2500, 1000]
|
||||
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||
include_static_vectors = false
|
||||
|
||||
[model.tok2vec.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||
width = ${model.tok2vec.embed.width}
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
depth = 4
|
||||
"""
|
||||
|
||||
DEFAULT_SPANCAT_MODEL = Config().from_str(spancat_default_config)["model"]
|
||||
DEFAULT_SPANCAT_SINGLELABEL_MODEL = Config().from_str(
|
||||
spancat_singlelabel_default_config
|
||||
)["model"]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -52,39 +83,42 @@ class Suggester(Protocol):
|
|||
...
|
||||
|
||||
|
||||
def ngram_suggester(
|
||||
docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
|
||||
) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
starts = ops.xp.arange(len(doc), dtype="i")
|
||||
starts = starts.reshape((-1, 1))
|
||||
length = 0
|
||||
for size in sizes:
|
||||
if size <= len(doc):
|
||||
starts_size = starts[: len(doc) - (size - 1)]
|
||||
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
|
||||
length += spans[-1].shape[0]
|
||||
if spans:
|
||||
assert spans[-1].ndim == 2, spans[-1].shape
|
||||
lengths.append(length)
|
||||
lengths_array = ops.asarray1i(lengths)
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.xp.vstack(spans), lengths_array)
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
|
||||
|
||||
assert output.dataXd.ndim == 2
|
||||
return output
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
position."""
|
||||
|
||||
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
starts = ops.xp.arange(len(doc), dtype="i")
|
||||
starts = starts.reshape((-1, 1))
|
||||
length = 0
|
||||
for size in sizes:
|
||||
if size <= len(doc):
|
||||
starts_size = starts[: len(doc) - (size - 1)]
|
||||
spans.append(ops.xp.hstack((starts_size, starts_size + size)))
|
||||
length += spans[-1].shape[0]
|
||||
if spans:
|
||||
assert spans[-1].ndim == 2, spans[-1].shape
|
||||
lengths.append(length)
|
||||
lengths_array = ops.asarray1i(lengths)
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.xp.vstack(spans), lengths_array)
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
|
||||
|
||||
assert output.dataXd.ndim == 2
|
||||
return output
|
||||
|
||||
return ngram_suggester
|
||||
return partial(ngram_suggester, sizes=sizes)
|
||||
|
||||
|
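The registered suggester can also be called directly; a small sketch that mirrors how the spancat tests later in this diff use it:

    import spacy
    from spacy import registry

    nlp = spacy.blank("en")
    doc = nlp.make_doc("Greater London")
    suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
    ragged = suggester([doc])
    # one row per candidate span: (start, end) token offsets
    print(ragged[0].dataXd)  # [[0 1] [1 2] [0 2]]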
||||
@registry.misc("spacy.ngram_range_suggester.v1")
|
||||
|
@ -119,10 +153,14 @@ def make_spancat(
|
|||
threshold: float,
|
||||
max_positive: Optional[int],
|
||||
) -> "SpanCategorizer":
|
||||
"""Create a SpanCategorizer component. The span categorizer consists of two
|
||||
"""Create a SpanCategorizer component and configure it for multi-label
|
||||
classification to be able to assign multiple labels for each span.
|
||||
The span categorizer consists of two
|
||||
parts: a suggester function that proposes candidate spans, and a labeller
|
||||
model that predicts one or more labels for each span.
|
||||
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||
Spans are returned as a ragged array with two integer columns, for the
|
||||
start and end positions.
|
||||
|
@ -144,12 +182,80 @@ def make_spancat(
|
|||
"""
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
suggester=suggester,
|
||||
model=model,
|
||||
spans_key=spans_key,
|
||||
threshold=threshold,
|
||||
max_positive=max_positive,
|
||||
suggester=suggester,
|
||||
name=name,
|
||||
spans_key=spans_key,
|
||||
negative_weight=None,
|
||||
allow_overlap=True,
|
||||
max_positive=max_positive,
|
||||
threshold=threshold,
|
||||
scorer=scorer,
|
||||
add_negative_label=False,
|
||||
)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"spancat_singlelabel",
|
||||
assigns=["doc.spans"],
|
||||
default_config={
|
||||
"spans_key": "sc",
|
||||
"model": DEFAULT_SPANCAT_SINGLELABEL_MODEL,
|
||||
"negative_weight": 1.0,
|
||||
"suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
|
||||
"scorer": {"@scorers": "spacy.spancat_scorer.v1"},
|
||||
"allow_overlap": True,
|
||||
},
|
||||
default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
|
||||
)
|
||||
def make_spancat_singlelabel(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
suggester: Suggester,
|
||||
model: Model[Tuple[List[Doc], Ragged], Floats2d],
|
||||
spans_key: str,
|
||||
negative_weight: float,
|
||||
allow_overlap: bool,
|
||||
scorer: Optional[Callable],
|
||||
) -> "SpanCategorizer":
|
||||
"""Create a SpanCategorizer component and configure it for multi-class
|
||||
classification. With this configuration each span can get at most one
|
||||
label. The span categorizer consists of two
|
||||
parts: a suggester function that proposes candidate spans, and a labeller
model that predicts at most one label per span.
|
||||
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||
Spans are returned as a ragged array with two integer columns, for the
|
||||
start and end positions.
|
||||
model (Model[Tuple[List[Doc], Ragged], Floats2d]): A model instance that
|
||||
is given a list of documents and (start, end) indices representing
|
||||
candidate span offsets. The model predicts a probability for each category
|
||||
for each span.
|
||||
spans_key (str): Key of the doc.spans dict to save the spans under. During
|
||||
initialization and training, the component will look for spans on the
|
||||
reference document under the same key.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||
spans allowed.
|
||||
negative_weight (float): Multiplier for the loss terms.
|
||||
Can be used to downweight the negative samples if there are too many.
|
||||
allow_overlap (bool): If True the data is assumed to contain overlapping spans.
|
||||
Otherwise it produces non-overlapping spans greedily prioritizing
|
||||
higher assigned label scores.
|
||||
"""
|
||||
return SpanCategorizer(
|
||||
nlp.vocab,
|
||||
model=model,
|
||||
suggester=suggester,
|
||||
name=name,
|
||||
spans_key=spans_key,
|
||||
negative_weight=negative_weight,
|
||||
allow_overlap=allow_overlap,
|
||||
max_positive=1,
|
||||
add_negative_label=True,
|
||||
threshold=None,
|
||||
scorer=scorer,
|
||||
)
|
||||
|
||||
|
@@ -172,6 +278,27 @@ def make_spancat_scorer():
return spancat_score

@dataclass
class _Intervals:
"""
Helper class to avoid storing overlapping spans.
"""

def __init__(self):
self.ranges = set()

def add(self, i, j):
for e in range(i, j):
self.ranges.add(e)

def __contains__(self, rang):
i, j = rang
for e in range(i, j):
if e in self.ranges:
return True
return False

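The `_Intervals` helper marks every token position covered by an accepted span, so any later candidate that touches one of those positions is reported as already seen. A tiny illustrative sketch:

    seen = _Intervals()
    seen.add(2, 5)           # claim token positions 2, 3 and 4
    print((4, 6) in seen)    # True  -> overlaps an accepted span, would be skipped
    print((5, 7) in seen)    # False -> disjoint, would be kept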
class SpanCategorizer(TrainablePipe):
|
||||
"""Pipeline component to label spans of text.
|
||||
|
||||
|
@ -185,25 +312,43 @@ class SpanCategorizer(TrainablePipe):
|
|||
suggester: Suggester,
|
||||
name: str = "spancat",
|
||||
*,
|
||||
add_negative_label: bool = False,
|
||||
spans_key: str = "spans",
|
||||
threshold: float = 0.5,
|
||||
negative_weight: Optional[float] = 1.0,
|
||||
allow_overlap: Optional[bool] = True,
|
||||
max_positive: Optional[int] = None,
|
||||
threshold: Optional[float] = 0.5,
|
||||
scorer: Optional[Callable] = spancat_score,
|
||||
) -> None:
|
||||
"""Initialize the span categorizer.
|
||||
"""Initialize the multi-label or multi-class span categorizer.
|
||||
|
||||
vocab (Vocab): The shared vocabulary.
|
||||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||
For multi-class classification (single label per span) we recommend
|
||||
using a Softmax classifier as a the final layer, while for multi-label
|
||||
classification (multiple possible labels per span) we recommend Logistic.
|
||||
suggester (Callable[[Iterable[Doc], Optional[Ops]], Ragged]): A function that suggests spans.
|
||||
Spans are returned as a ragged array with two integer columns, for the
|
||||
start and end positions.
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
spans_key (str): Key of the Doc.spans dict to save the spans under.
|
||||
During initialization and training, the component will look for
|
||||
spans on the reference document under the same key. Defaults to
|
||||
`"spans"`.
|
||||
threshold (float): Minimum probability to consider a prediction
|
||||
positive. Spans with a positive prediction will be saved on the Doc.
|
||||
Defaults to 0.5.
|
||||
add_negative_label (bool): Learn to predict a special 'negative_label'
|
||||
when a Span is not annotated.
|
||||
threshold (Optional[float]): Minimum probability to consider a prediction
|
||||
positive. Defaults to 0.5. Spans with a positive prediction will be saved
|
||||
on the Doc.
|
||||
max_positive (Optional[int]): Maximum number of labels to consider
|
||||
positive per span. Defaults to None, indicating no limit.
|
||||
negative_weight (float): Multiplier for the loss terms.
Can be used to downweight the negative samples if there are too many
when add_negative_label is True. Otherwise it is unused.
|
||||
allow_overlap (bool): If True the data is assumed to contain overlapping spans.
|
||||
Otherwise it produces non-overlapping spans greedily prioritizing
|
||||
higher assigned label scores. Only used when max_positive is 1.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to
|
||||
Scorer.score_spans for the Doc.spans[spans_key] with overlapping
|
||||
spans allowed.
|
||||
|
@ -215,12 +360,17 @@ class SpanCategorizer(TrainablePipe):
|
|||
"spans_key": spans_key,
|
||||
"threshold": threshold,
|
||||
"max_positive": max_positive,
|
||||
"negative_weight": negative_weight,
|
||||
"allow_overlap": allow_overlap,
|
||||
}
|
||||
self.vocab = vocab
|
||||
self.suggester = suggester
|
||||
self.model = model
|
||||
self.name = name
|
||||
self.scorer = scorer
|
||||
self.add_negative_label = add_negative_label
|
||||
if not allow_overlap and max_positive is not None and max_positive > 1:
|
||||
raise ValueError(Errors.E1051.format(max_positive=max_positive))
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
|
@ -230,6 +380,21 @@ class SpanCategorizer(TrainablePipe):
|
|||
"""
|
||||
return str(self.cfg["spans_key"])
|
||||
|
||||
def _allow_extra_label(self) -> None:
|
||||
"""Raise an error if the component can not add any more labels."""
|
||||
nO = None
|
||||
if self.model.has_dim("nO"):
|
||||
nO = self.model.get_dim("nO")
|
||||
elif self.model.has_ref("output_layer") and self.model.get_ref(
|
||||
"output_layer"
|
||||
).has_dim("nO"):
|
||||
nO = self.model.get_ref("output_layer").get_dim("nO")
|
||||
if nO is not None and nO == self._n_labels:
|
||||
if not self.is_resizable:
|
||||
raise ValueError(
|
||||
Errors.E922.format(name=self.name, nO=self.model.get_dim("nO"))
|
||||
)
|
||||
|
||||
def add_label(self, label: str) -> int:
|
||||
"""Add a new label to the pipe.
|
||||
|
||||
|
@ -263,6 +428,27 @@ class SpanCategorizer(TrainablePipe):
|
|||
"""
|
||||
return list(self.labels)
|
||||
|
||||
@property
|
||||
def _label_map(self) -> Dict[str, int]:
|
||||
"""RETURNS (Dict[str, int]): The label map."""
|
||||
return {label: i for i, label in enumerate(self.labels)}
|
||||
|
||||
@property
|
||||
def _n_labels(self) -> int:
|
||||
"""RETURNS (int): Number of labels."""
|
||||
if self.add_negative_label:
|
||||
return len(self.labels) + 1
|
||||
else:
|
||||
return len(self.labels)
|
||||
|
||||
@property
|
||||
def _negative_label_i(self) -> Union[int, None]:
|
||||
"""RETURNS (Union[int, None]): Index of the negative label."""
|
||||
if self.add_negative_label:
|
||||
return len(self.label_data)
|
||||
else:
|
||||
return None
|
||||
|
||||
def predict(self, docs: Iterable[Doc]):
|
||||
"""Apply the pipeline's model to a batch of docs, without modifying them.
|
||||
|
||||
|
@ -304,14 +490,24 @@ class SpanCategorizer(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/spancategorizer#set_annotations
|
||||
"""
|
||||
labels = self.labels
|
||||
indices, scores = indices_scores
|
||||
offset = 0
|
||||
for i, doc in enumerate(docs):
|
||||
indices_i = indices[i].dataXd
|
||||
doc.spans[self.key] = self._make_span_group(
|
||||
doc, indices_i, scores[offset : offset + indices.lengths[i]], labels # type: ignore[arg-type]
|
||||
)
|
||||
allow_overlap = cast(bool, self.cfg["allow_overlap"])
|
||||
if self.cfg["max_positive"] == 1:
|
||||
doc.spans[self.key] = self._make_span_group_singlelabel(
|
||||
doc,
|
||||
indices_i,
|
||||
scores[offset : offset + indices.lengths[i]],
|
||||
allow_overlap,
|
||||
)
|
||||
else:
|
||||
doc.spans[self.key] = self._make_span_group_multilabel(
|
||||
doc,
|
||||
indices_i,
|
||||
scores[offset : offset + indices.lengths[i]],
|
||||
)
|
||||
offset += indices.lengths[i]
|
||||
|
||||
def update(
|
||||
|
@ -371,9 +567,11 @@ class SpanCategorizer(TrainablePipe):
|
|||
spans = Ragged(
|
||||
self.model.ops.to_numpy(spans.data), self.model.ops.to_numpy(spans.lengths)
|
||||
)
|
||||
label_map = {label: i for i, label in enumerate(self.labels)}
|
||||
target = numpy.zeros(scores.shape, dtype=scores.dtype)
|
||||
if self.add_negative_label:
|
||||
negative_spans = numpy.ones((scores.shape[0]))
|
||||
offset = 0
|
||||
label_map = self._label_map
|
||||
for i, eg in enumerate(examples):
|
||||
# Map (start, end) offset of spans to the row in the d_scores array,
|
||||
# so that we can adjust the gradient for predictions that were
|
||||
|
@ -390,10 +588,16 @@ class SpanCategorizer(TrainablePipe):
|
|||
row = spans_index[key]
|
||||
k = label_map[gold_span.label_]
|
||||
target[row, k] = 1.0
|
||||
if self.add_negative_label:
|
||||
# delete negative label target.
|
||||
negative_spans[row] = 0.0
|
||||
# The target is a flat array for all docs. Track the position
|
||||
# we're at within the flat array.
|
||||
offset += spans.lengths[i]
|
||||
target = self.model.ops.asarray(target, dtype="f") # type: ignore
|
||||
if self.add_negative_label:
|
||||
negative_samples = numpy.nonzero(negative_spans)[0]
|
||||
target[negative_samples, self._negative_label_i] = 1.0 # type: ignore
|
||||
# The target will have the values 0 (for untrue predictions) or 1
|
||||
# (for true predictions).
|
||||
# The scores should be in the range [0, 1].
|
||||
|
@ -402,6 +606,10 @@ class SpanCategorizer(TrainablePipe):
|
|||
# If the prediction is 0.9 and it's false, the gradient will be
|
||||
# 0.9 (0.9 - 0.0)
|
||||
d_scores = scores - target
|
||||
if self.add_negative_label:
|
||||
neg_weight = cast(float, self.cfg["negative_weight"])
|
||||
if neg_weight != 1.0:
|
||||
d_scores[negative_samples] *= neg_weight
|
||||
loss = float((d_scores**2).sum())
|
||||
return loss, d_scores
|
||||
|
||||
|
@ -438,7 +646,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
if subbatch:
|
||||
docs = [eg.x for eg in subbatch]
|
||||
spans = build_ngram_suggester(sizes=[1])(docs)
|
||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], len(self.labels))
|
||||
Y = self.model.ops.alloc2f(spans.dataXd.shape[0], self._n_labels)
|
||||
self.model.initialize(X=(docs, spans), Y=Y)
|
||||
else:
|
||||
self.model.initialize()
|
||||
|
@ -452,31 +660,98 @@ class SpanCategorizer(TrainablePipe):
|
|||
eg.reference.spans.get(self.key, []), allow_overlap=True
|
||||
)
|
||||
|
||||
def _make_span_group(
|
||||
self, doc: Doc, indices: Ints2d, scores: Floats2d, labels: List[str]
|
||||
def _make_span_group_multilabel(
|
||||
self,
|
||||
doc: Doc,
|
||||
indices: Ints2d,
|
||||
scores: Floats2d,
|
||||
) -> SpanGroup:
|
||||
"""Find the top-k labels for each span (k=max_positive)."""
|
||||
spans = SpanGroup(doc, name=self.key)
|
||||
max_positive = self.cfg["max_positive"]
|
||||
if scores.size == 0:
|
||||
return spans
|
||||
scores = self.model.ops.to_numpy(scores)
|
||||
indices = self.model.ops.to_numpy(indices)
|
||||
threshold = self.cfg["threshold"]
|
||||
max_positive = self.cfg["max_positive"]
|
||||
|
||||
keeps = scores >= threshold
|
||||
ranked = (scores * -1).argsort() # type: ignore
|
||||
if max_positive is not None:
|
||||
assert isinstance(max_positive, int)
|
||||
if self.add_negative_label:
|
||||
negative_scores = numpy.copy(scores[:, self._negative_label_i])
|
||||
scores[:, self._negative_label_i] = -numpy.inf
|
||||
ranked = (scores * -1).argsort() # type: ignore
|
||||
scores[:, self._negative_label_i] = negative_scores
|
||||
else:
|
||||
ranked = (scores * -1).argsort() # type: ignore
|
||||
span_filter = ranked[:, max_positive:]
|
||||
for i, row in enumerate(span_filter):
|
||||
keeps[i, row] = False
|
||||
spans.attrs["scores"] = scores[keeps].flatten()
|
||||
|
||||
indices = self.model.ops.to_numpy(indices)
|
||||
keeps = self.model.ops.to_numpy(keeps)
|
||||
|
||||
attrs_scores = []
|
||||
for i in range(indices.shape[0]):
|
||||
start = indices[i, 0]
|
||||
end = indices[i, 1]
|
||||
|
||||
for j, keep in enumerate(keeps[i]):
|
||||
if keep:
|
||||
spans.append(Span(doc, start, end, label=labels[j]))
|
||||
|
||||
if j != self._negative_label_i:
|
||||
spans.append(Span(doc, start, end, label=self.labels[j]))
|
||||
attrs_scores.append(scores[i, j])
|
||||
spans.attrs["scores"] = numpy.array(attrs_scores)
|
||||
return spans
|
||||
|
||||
def _make_span_group_singlelabel(
|
||||
self,
|
||||
doc: Doc,
|
||||
indices: Ints2d,
|
||||
scores: Floats2d,
|
||||
allow_overlap: bool = True,
|
||||
) -> SpanGroup:
|
||||
"""Find the argmax label for each span."""
|
||||
# Handle cases when there are zero suggestions
|
||||
if scores.size == 0:
|
||||
return SpanGroup(doc, name=self.key)
|
||||
scores = self.model.ops.to_numpy(scores)
|
||||
indices = self.model.ops.to_numpy(indices)
|
||||
predicted = scores.argmax(axis=1)
|
||||
argmax_scores = numpy.take_along_axis(
|
||||
scores, numpy.expand_dims(predicted, 1), axis=1
|
||||
)
|
||||
keeps = numpy.ones(predicted.shape, dtype=bool)
|
||||
# Remove samples where the negative label is the argmax.
|
||||
if self.add_negative_label:
|
||||
keeps = numpy.logical_and(keeps, predicted != self._negative_label_i)
|
||||
# Filter samples according to threshold.
|
||||
threshold = self.cfg["threshold"]
|
||||
if threshold is not None:
|
||||
keeps = numpy.logical_and(keeps, (argmax_scores >= threshold).squeeze())
|
||||
# Sort spans according to argmax probability
|
||||
if not allow_overlap:
|
||||
# Get the probabilities
|
||||
sort_idx = (argmax_scores.squeeze() * -1).argsort()
|
||||
argmax_scores = argmax_scores[sort_idx]
|
||||
predicted = predicted[sort_idx]
|
||||
indices = indices[sort_idx]
|
||||
keeps = keeps[sort_idx]
|
||||
seen = _Intervals()
|
||||
spans = SpanGroup(doc, name=self.key)
|
||||
attrs_scores = []
|
||||
for i in range(indices.shape[0]):
|
||||
if not keeps[i]:
|
||||
continue
|
||||
|
||||
label = predicted[i]
|
||||
start = indices[i, 0]
|
||||
end = indices[i, 1]
|
||||
|
||||
if not allow_overlap:
|
||||
if (start, end) in seen:
|
||||
continue
|
||||
else:
|
||||
seen.add(start, end)
|
||||
attrs_scores.append(argmax_scores[i])
|
||||
spans.append(Span(doc, start, end, label=self.labels[label]))
|
||||
|
||||
spans.attrs["scores"] = numpy.array(attrs_scores)
|
||||
return spans
|
||||
|
|
|
@@ -45,7 +45,7 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"tagger",
assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!", "label_smoothing": 0.0},
default_score_weights={"tag_acc": 1.0},
)
def make_tagger(

@@ -55,6 +55,7 @@ def make_tagger(
overwrite: bool,
scorer: Optional[Callable],
neg_prefix: str,
label_smoothing: float,
):
"""Construct a part-of-speech tagger component.

@@ -63,7 +64,7 @@ def make_tagger(
in size, and be normalized as probabilities (all scores between 0 and 1,
with the rows summing to 1).
"""
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, label_smoothing=label_smoothing)

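The new `label_smoothing` setting is plumbed from the factory through `Tagger.cfg` into the loss in `get_loss` below, so enabling it is just a config override. A minimal sketch (assuming a spaCy build with this change):

    import spacy

    nlp = spacy.blank("en")
    # 0.0 (the default) keeps the previous behaviour; a small positive value
    # smooths the one-hot tag targets used by SequenceCategoricalCrossentropy
    tagger = nlp.add_pipe("tagger", config={"label_smoothing": 0.05})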
def tagger_score(examples, **kwargs):
|
||||
|
@ -89,6 +90,7 @@ class Tagger(TrainablePipe):
|
|||
overwrite=BACKWARD_OVERWRITE,
|
||||
scorer=tagger_score,
|
||||
neg_prefix="!",
|
||||
label_smoothing=0.0,
|
||||
):
|
||||
"""Initialize a part-of-speech tagger.
|
||||
|
||||
|
@ -105,7 +107,7 @@ class Tagger(TrainablePipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self._rehearsal_model = None
|
||||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
||||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix, "label_smoothing": label_smoothing}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
self.scorer = scorer
|
||||
|
||||
|
@ -256,7 +258,7 @@ class Tagger(TrainablePipe):
|
|||
DOCS: https://spacy.io/api/tagger#get_loss
|
||||
"""
|
||||
validate_examples(examples, "Tagger.get_loss")
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"], label_smoothing=self.cfg["label_smoothing"])
|
||||
# Convert empty tag "" to missing value None so that both misaligned
|
||||
# tokens and tokens with missing annotation have the default missing
|
||||
# value None.
|
||||
|
|
|
@ -74,7 +74,7 @@ subword_features = true
|
|||
default_config={
|
||||
"threshold": 0.0,
|
||||
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_scorer.v1"},
|
||||
"scorer": {"@scorers": "spacy.textcat_scorer.v2"},
|
||||
},
|
||||
default_score_weights={
|
||||
"cats_score": 1.0,
|
||||
|
@ -117,7 +117,7 @@ def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.textcat_scorer.v1")
|
||||
@registry.scorers("spacy.textcat_scorer.v2")
|
||||
def make_textcat_scorer():
|
||||
return textcat_score
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ subword_features = true
|
|||
default_config={
|
||||
"threshold": 0.5,
|
||||
"model": DEFAULT_MULTI_TEXTCAT_MODEL,
|
||||
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
|
||||
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v2"},
|
||||
},
|
||||
default_score_weights={
|
||||
"cats_score": 1.0,
|
||||
|
@ -120,7 +120,7 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str,
|
|||
)
|
||||
|
||||
|
||||
@registry.scorers("spacy.textcat_multilabel_scorer.v1")
|
||||
@registry.scorers("spacy.textcat_multilabel_scorer.v2")
|
||||
def make_textcat_multilabel_scorer():
|
||||
return textcat_multilabel_score
|
||||
|
||||
|
|
|
@ -156,12 +156,40 @@ def validate_token_pattern(obj: list) -> List[str]:
|
|||
|
||||
|
||||
class TokenPatternString(BaseModel):
|
||||
REGEX: Optional[StrictStr] = Field(None, alias="regex")
|
||||
REGEX: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="regex")
|
||||
IN: Optional[List[StrictStr]] = Field(None, alias="in")
|
||||
NOT_IN: Optional[List[StrictStr]] = Field(None, alias="not_in")
|
||||
IS_SUBSET: Optional[List[StrictStr]] = Field(None, alias="is_subset")
|
||||
IS_SUPERSET: Optional[List[StrictStr]] = Field(None, alias="is_superset")
|
||||
INTERSECTS: Optional[List[StrictStr]] = Field(None, alias="intersects")
|
||||
FUZZY: Optional[Union[StrictStr, "TokenPatternString"]] = Field(None, alias="fuzzy")
|
||||
FUZZY1: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy1"
|
||||
)
|
||||
FUZZY2: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy2"
|
||||
)
|
||||
FUZZY3: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy3"
|
||||
)
|
||||
FUZZY4: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy4"
|
||||
)
|
||||
FUZZY5: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy5"
|
||||
)
|
||||
FUZZY6: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy6"
|
||||
)
|
||||
FUZZY7: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy7"
|
||||
)
|
||||
FUZZY8: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy8"
|
||||
)
|
||||
FUZZY9: Optional[Union[StrictStr, "TokenPatternString"]] = Field(
|
||||
None, alias="fuzzy9"
|
||||
)
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
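The widened `TokenPatternString` schema above is what lets Matcher patterns nest `FUZZY`/`FUZZYn` and set predicates under `REGEX`. A short sketch of pattern shapes that now validate, taken from the matcher tests later in this diff:

    from spacy.matcher import Matcher
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    matcher = Matcher(vocab, validate=True)
    # FUZZY with a plain string, FUZZYn with an explicit edit budget,
    # and REGEX taking a set predicate instead of a single string
    matcher.add("GoogleNow", [[{"ORTH": {"FUZZY": "Google"}}, {"ORTH": "Now"}]])
    matcher.add("JS", [[{"ORTH": {"FUZZY5": "JavaScripts"}}]])
    matcher.add("A_OR_AN", [[{"ORTH": {"REGEX": {"IN": [r"(?:a)", r"(?:an)"]}}}]])
    doc = Doc(vocab, words=["They", "like", "Goggle", "Now", "and", "an", "a"])
    matches = matcher(doc)  # "Goggle Now" matches fuzzily; "an" and "a" match the regex set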
|
|
@@ -174,7 +174,7 @@ class Scorer:
prf_score.score_set(pred_spans, gold_spans)
if len(acc_score) > 0:
return {
"token_acc": acc_score.fscore,
"token_acc": acc_score.precision,
"token_p": prf_score.precision,
"token_r": prf_score.recall,
"token_f": prf_score.fscore,
|
@ -476,14 +476,12 @@ class Scorer:
|
|||
f_per_type = {label: PRFScore() for label in labels}
|
||||
auc_per_type = {label: ROCAUCScore() for label in labels}
|
||||
labels = set(labels)
|
||||
if labels:
|
||||
for eg in examples:
|
||||
labels.update(eg.predicted.cats.keys())
|
||||
labels.update(eg.reference.cats.keys())
|
||||
for example in examples:
|
||||
# Through this loop, None in the gold_cats indicates missing label.
|
||||
pred_cats = getter(example.predicted, attr)
|
||||
pred_cats = {k: v for k, v in pred_cats.items() if k in labels}
|
||||
gold_cats = getter(example.reference, attr)
|
||||
gold_cats = {k: v for k, v in gold_cats.items() if k in labels}
|
||||
|
||||
for label in labels:
|
||||
pred_score = pred_cats.get(label, 0.0)
|
||||
|
|
|
@ -163,6 +163,18 @@ def test_char_span(doc, i_sent, i, j, text):
|
|||
assert span.text == text
|
||||
|
||||
|
||||
def test_char_span_attributes(doc):
|
||||
label = "LABEL"
|
||||
kb_id = "KB_ID"
|
||||
span_id = "SPAN_ID"
|
||||
span1 = doc.char_span(20, 45, label=label, kb_id=kb_id, span_id=span_id)
|
||||
span2 = doc[1:].char_span(15, 40, label=label, kb_id=kb_id, span_id=span_id)
|
||||
assert span1.text == span2.text
|
||||
assert span1.label_ == span2.label_ == label
|
||||
assert span1.kb_id_ == span2.kb_id_ == kb_id
|
||||
assert span1.id_ == span2.id_ == span_id
|
||||
|
||||
|
||||
def test_spans_sent_spans(doc):
|
||||
sents = list(doc.sents)
|
||||
assert sents[0].start == 0
|
||||
|
@ -367,6 +379,14 @@ def test_spans_by_character(doc):
|
|||
span1.start_char + 1, span1.end_char, label="GPE", alignment_mode="unk"
|
||||
)
|
||||
|
||||
# Span.char_span + alignment mode "contract"
|
||||
span2 = doc[0:2].char_span(
|
||||
span1.start_char - 3, span1.end_char, label="GPE", alignment_mode="contract"
|
||||
)
|
||||
assert span1.start_char == span2.start_char
|
||||
assert span1.end_char == span2.end_char
|
||||
assert span2.label_ == "GPE"
|
||||
|
||||
|
||||
def test_span_to_array(doc):
|
||||
span = doc[1:-2]
|
||||
|
@ -696,3 +716,18 @@ def test_for_partial_ent_sents():
|
|||
# equal to the sentences referenced in ent.sents.
|
||||
for doc_sent, ent_sent in zip(doc.sents, doc.ents[0].sents):
|
||||
assert doc_sent == ent_sent
|
||||
|
||||
|
||||
def test_for_no_ent_sents():
|
||||
"""Span.sents() should set .sents correctly, even if Span in question is trailing and doesn't form a full
|
||||
sentence.
|
||||
"""
|
||||
doc = Doc(
|
||||
English().vocab,
|
||||
words=["This", "is", "a", "test.", "ENTITY"],
|
||||
sent_starts=[1, 0, 0, 0, 1],
|
||||
)
|
||||
doc.set_ents([Span(doc, 4, 5, "WORK")])
|
||||
sents = list(doc.ents[0].sents)
|
||||
assert len(sents) == 1
|
||||
assert str(sents[0]) == str(doc.ents[0].sent) == "ENTITY"
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from typing import List
|
||||
|
||||
import pytest
|
||||
from random import Random
|
||||
from spacy.matcher import Matcher
|
||||
from spacy.tokens import Span, SpanGroup
|
||||
from spacy.tokens import Span, SpanGroup, Doc
|
||||
from spacy.util import filter_spans
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -240,3 +243,13 @@ def test_span_group_extend(doc):
|
|||
def test_span_group_dealloc(span_group):
|
||||
with pytest.raises(AttributeError):
|
||||
print(span_group.doc)
|
||||
|
||||
|
||||
@pytest.mark.issue(11975)
|
||||
def test_span_group_typing(doc: Doc):
|
||||
"""Tests whether typing of `SpanGroup` as `Iterable[Span]`-like object is accepted by mypy."""
|
||||
span_group: SpanGroup = doc.spans["SPANS"]
|
||||
spans: List[Span] = list(span_group)
|
||||
for i, span in enumerate(span_group):
|
||||
assert span == span_group[i] == spans[i]
|
||||
filter_spans(span_group)
|
||||
|
|
|
@ -32,3 +32,10 @@ def test_tokenizer_splits_comma_infix(sv_tokenizer, text):
|
|||
def test_tokenizer_splits_ellipsis_infix(sv_tokenizer, text):
|
||||
tokens = sv_tokenizer(text)
|
||||
assert len(tokens) == 3
|
||||
|
||||
|
||||
@pytest.mark.issue(12311)
|
||||
@pytest.mark.parametrize("text", ["99:e", "c:a", "EU:s", "Maj:t"])
|
||||
def test_sv_tokenizer_handles_colon(sv_tokenizer, text):
|
||||
tokens = sv_tokenizer(text)
|
||||
assert len(tokens) == 1
|
||||
|
|
|
@ -316,16 +316,32 @@ def test_dependency_matcher_precedence_ops(en_vocab, op, num_matches):
|
|||
("the", "brown", "$--", 0),
|
||||
("brown", "the", "$--", 1),
|
||||
("brown", "brown", "$--", 0),
|
||||
("over", "jumped", "<+", 0),
|
||||
("quick", "fox", "<+", 0),
|
||||
("the", "quick", "<+", 0),
|
||||
("brown", "fox", "<+", 1),
|
||||
("quick", "fox", "<++", 1),
|
||||
("quick", "over", "<++", 0),
|
||||
("over", "jumped", "<++", 0),
|
||||
("the", "fox", "<++", 2),
|
||||
("brown", "fox", "<-", 0),
|
||||
("fox", "over", "<-", 0),
|
||||
("the", "over", "<-", 0),
|
||||
("over", "jumped", "<-", 1),
|
||||
("brown", "fox", "<--", 0),
|
||||
("fox", "jumped", "<--", 0),
|
||||
("fox", "over", "<--", 1),
|
||||
("fox", "brown", ">+", 0),
|
||||
("over", "fox", ">+", 0),
|
||||
("over", "the", ">+", 0),
|
||||
("jumped", "over", ">+", 1),
|
||||
("jumped", "over", ">++", 1),
|
||||
("fox", "lazy", ">++", 0),
|
||||
("over", "the", ">++", 0),
|
||||
("jumped", "over", ">-", 0),
|
||||
("fox", "quick", ">-", 0),
|
||||
("brown", "quick", ">-", 0),
|
||||
("fox", "brown", ">-", 1),
|
||||
("brown", "fox", ">--", 0),
|
||||
("fox", "brown", ">--", 1),
|
||||
("jumped", "fox", ">--", 1),
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from spacy.matcher import levenshtein
|
||||
from spacy.matcher.levenshtein import levenshtein_compare
|
||||
|
||||
|
||||
# empty string plus 10 random ASCII, 10 random unicode, and 2 random long tests
|
||||
|
@ -42,3 +43,31 @@ from spacy.matcher import levenshtein
|
|||
)
|
||||
def test_levenshtein(dist, a, b):
|
||||
assert levenshtein(a, b) == dist
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a,b,fuzzy,expected",
|
||||
[
|
||||
("a", "a", 1, True),
|
||||
("a", "a", 0, True),
|
||||
("a", "a", -1, True),
|
||||
("a", "ab", 1, True),
|
||||
("a", "ab", 0, False),
|
||||
("a", "ab", -1, True),
|
||||
("ab", "ac", 1, True),
|
||||
("ab", "ac", -1, True),
|
||||
("abc", "cde", 4, True),
|
||||
("abc", "cde", -1, False),
|
||||
("abcdef", "cdefgh", 4, True),
|
||||
("abcdef", "cdefgh", 3, False),
|
||||
("abcdef", "cdefgh", -1, False), # default (2 for length 6)
|
||||
("abcdefgh", "cdefghijk", 5, True),
|
||||
("abcdefgh", "cdefghijk", 4, False),
|
||||
("abcdefgh", "cdefghijk", -1, False), # default (2)
|
||||
("abcdefgh", "cdefghijkl", 6, True),
|
||||
("abcdefgh", "cdefghijkl", 5, False),
|
||||
("abcdefgh", "cdefghijkl", -1, False), # default (2)
|
||||
],
|
||||
)
|
||||
def test_levenshtein_compare(a, b, fuzzy, expected):
|
||||
assert levenshtein_compare(a, b, fuzzy) == expected
|
||||
|
|
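As the parametrization above indicates, passing `fuzzy=-1` makes `levenshtein_compare` fall back to a length-dependent default edit budget (2 for the six-character examples), while an explicit value sets the budget directly:

    from spacy.matcher.levenshtein import levenshtein_compare

    print(levenshtein_compare("ab", "ac", -1))          # True: within the default budget
    print(levenshtein_compare("abcdef", "cdefgh", -1))  # False: 4 edits > default budget of 2
    print(levenshtein_compare("abcdef", "cdefgh", 4))   # True: explicit budget of 4 edits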
|
@ -118,6 +118,155 @@ def test_matcher_match_multi(matcher):
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rules,match_locs",
|
||||
[
|
||||
(
|
||||
{
|
||||
"GoogleNow": [[{"ORTH": {"FUZZY": "Google"}}, {"ORTH": "Now"}]],
|
||||
},
|
||||
[(2, 4)],
|
||||
),
|
||||
(
|
||||
{
|
||||
"Java": [[{"LOWER": {"FUZZY": "java"}}]],
|
||||
},
|
||||
[(5, 6)],
|
||||
),
|
||||
(
|
||||
{
|
||||
"JS": [[{"ORTH": {"FUZZY": "JavaScript"}}]],
|
||||
"GoogleNow": [[{"ORTH": {"FUZZY": "Google"}}, {"ORTH": "Now"}]],
|
||||
"Java": [[{"LOWER": {"FUZZY": "java"}}]],
|
||||
},
|
||||
[(2, 4), (5, 6), (8, 9)],
|
||||
),
|
||||
# only the second pattern matches (check that predicate keys used for
|
||||
# caching don't collide)
|
||||
(
|
||||
{
|
||||
"A": [[{"ORTH": {"FUZZY": "Javascripts"}}]],
|
||||
"B": [[{"ORTH": {"FUZZY5": "Javascripts"}}]],
|
||||
},
|
||||
[(8, 9)],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_matcher_match_fuzzy(en_vocab, rules, match_locs):
|
||||
words = ["They", "like", "Goggle", "Now", "and", "Jav", "but", "not", "JvvaScrpt"]
|
||||
doc = Doc(en_vocab, words=words)
|
||||
|
||||
matcher = Matcher(en_vocab)
|
||||
for key, patterns in rules.items():
|
||||
matcher.add(key, patterns)
|
||||
assert match_locs == [(start, end) for m_id, start, end in matcher(doc)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("set_op", ["IN", "NOT_IN"])
|
||||
def test_matcher_match_fuzzy_set_op_longest(en_vocab, set_op):
|
||||
rules = {
|
||||
"GoogleNow": [[{"ORTH": {"FUZZY": {set_op: ["Google", "Now"]}}, "OP": "+"}]]
|
||||
}
|
||||
matcher = Matcher(en_vocab)
|
||||
for key, patterns in rules.items():
|
||||
matcher.add(key, patterns, greedy="LONGEST")
|
||||
|
||||
words = ["They", "like", "Goggle", "Noo"]
|
||||
doc = Doc(en_vocab, words=words)
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
|
||||
def test_matcher_match_fuzzy_set_multiple(en_vocab):
|
||||
rules = {
|
||||
"GoogleNow": [
|
||||
[
|
||||
{
|
||||
"ORTH": {"FUZZY": {"IN": ["Google", "Now"]}, "NOT_IN": ["Goggle"]},
|
||||
"OP": "+",
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
matcher = Matcher(en_vocab)
|
||||
for key, patterns in rules.items():
|
||||
matcher.add(key, patterns, greedy="LONGEST")
|
||||
|
||||
words = ["They", "like", "Goggle", "Noo"]
|
||||
doc = Doc(matcher.vocab, words=words)
|
||||
assert matcher(doc) == [
|
||||
(doc.vocab.strings["GoogleNow"], 3, 4),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fuzzyn", range(1, 10))
|
||||
def test_matcher_match_fuzzyn_all_insertions(en_vocab, fuzzyn):
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("GoogleNow", [[{"ORTH": {f"FUZZY{fuzzyn}": "GoogleNow"}}]])
|
||||
# words with increasing edit distance
|
||||
words = ["GoogleNow" + "a" * i for i in range(0, 10)]
|
||||
doc = Doc(en_vocab, words)
|
||||
assert len(matcher(doc)) == fuzzyn + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fuzzyn", range(1, 6))
|
||||
def test_matcher_match_fuzzyn_various_edits(en_vocab, fuzzyn):
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("GoogleNow", [[{"ORTH": {f"FUZZY{fuzzyn}": "GoogleNow"}}]])
|
||||
# words with increasing edit distance of different edit types
|
||||
words = [
|
||||
"GoogleNow",
|
||||
"GoogleNuw",
|
||||
"GoogleNuew",
|
||||
"GoogleNoweee",
|
||||
"GiggleNuw3",
|
||||
"gouggle5New",
|
||||
]
|
||||
doc = Doc(en_vocab, words)
|
||||
assert len(matcher(doc)) == fuzzyn + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("greedy", ["FIRST", "LONGEST"])
|
||||
@pytest.mark.parametrize("set_op", ["IN", "NOT_IN"])
|
||||
def test_matcher_match_fuzzyn_set_op_longest(en_vocab, greedy, set_op):
|
||||
rules = {
|
||||
"GoogleNow": [[{"ORTH": {"FUZZY2": {set_op: ["Google", "Now"]}}, "OP": "+"}]]
|
||||
}
|
||||
matcher = Matcher(en_vocab)
|
||||
for key, patterns in rules.items():
|
||||
matcher.add(key, patterns, greedy=greedy)
|
||||
|
||||
words = ["They", "like", "Goggle", "Noo"]
|
||||
doc = Doc(matcher.vocab, words=words)
|
||||
spans = matcher(doc, as_spans=True)
|
||||
assert len(spans) == 1
|
||||
if set_op == "IN":
|
||||
assert spans[0].text == "Goggle Noo"
|
||||
else:
|
||||
assert spans[0].text == "They like"
|
||||
|
||||
|
||||
def test_matcher_match_fuzzyn_set_multiple(en_vocab):
|
||||
rules = {
|
||||
"GoogleNow": [
|
||||
[
|
||||
{
|
||||
"ORTH": {"FUZZY1": {"IN": ["Google", "Now"]}, "NOT_IN": ["Goggle"]},
|
||||
"OP": "+",
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
matcher = Matcher(en_vocab)
|
||||
for key, patterns in rules.items():
|
||||
matcher.add(key, patterns, greedy="LONGEST")
|
||||
|
||||
words = ["They", "like", "Goggle", "Noo"]
|
||||
doc = Doc(matcher.vocab, words=words)
|
||||
assert matcher(doc) == [
|
||||
(doc.vocab.strings["GoogleNow"], 3, 4),
|
||||
]
|
||||
|
||||
|
||||
def test_matcher_empty_dict(en_vocab):
|
||||
"""Test matcher allows empty token specs, meaning match on any token."""
|
||||
matcher = Matcher(en_vocab)
|
||||
|
@ -437,6 +586,30 @@ def test_matcher_regex(en_vocab):
|
|||
assert len(matches) == 0
|
||||
|
||||
|
||||
def test_matcher_regex_set_in(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"ORTH": {"REGEX": {"IN": [r"(?:a)", r"(?:an)"]}}}]
|
||||
matcher.add("A_OR_AN", [pattern])
|
||||
doc = Doc(en_vocab, words=["an", "a", "hi"])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=["bye"])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
def test_matcher_regex_set_not_in(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"ORTH": {"REGEX": {"NOT_IN": [r"(?:a)", r"(?:an)"]}}}]
|
||||
matcher.add("A_OR_AN", [pattern])
|
||||
doc = Doc(en_vocab, words=["an", "a", "hi"])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
doc = Doc(en_vocab, words=["bye"])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_matcher_regex_shape(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"SHAPE": {"REGEX": r"^[^x]+$"}}]
|
||||
|
|
|
@ -9,6 +9,8 @@ from spacy.lang.en import English
|
|||
from spacy.lang.it import Italian
|
||||
from spacy.language import Language
|
||||
from spacy.lookups import Lookups
|
||||
from spacy.pipeline import EntityRecognizer
|
||||
from spacy.pipeline.ner import DEFAULT_NER_MODEL
|
||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||
from spacy.training import Example, iob_to_biluo, split_bilu_label
|
||||
from spacy.tokens import Doc, Span
|
||||
|
@ -16,8 +18,6 @@ from spacy.vocab import Vocab
|
|||
import logging
|
||||
|
||||
from ..util import make_tempdir
|
||||
from ...pipeline import EntityRecognizer
|
||||
from ...pipeline.ner import DEFAULT_NER_MODEL
|
||||
|
||||
TRAIN_DATA = [
|
||||
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
|
||||
|
|
|
@ -8,11 +8,11 @@ from spacy.lang.en import English
|
|||
from spacy.tokens import Doc
|
||||
from spacy.training import Example
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.pipeline import DependencyParser
|
||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
|
||||
from ...pipeline import DependencyParser
|
||||
from ...pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
from ..util import apply_transition_sequence, make_tempdir
|
||||
from ...pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
|
|
|
@ -101,14 +101,15 @@ def test_initialize_from_labels():
|
|||
}
|
||||
|
||||
|
||||
def test_no_data():
|
||||
@pytest.mark.parametrize("top_k", (1, 5, 30))
|
||||
def test_no_data(top_k):
|
||||
# Test that the lemmatizer provides a nice error when there's no tagging data / labels
|
||||
TEXTCAT_DATA = [
|
||||
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
|
||||
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
|
||||
]
|
||||
nlp = English()
|
||||
nlp.add_pipe("trainable_lemmatizer")
|
||||
nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
|
||||
nlp.add_pipe("textcat")
|
||||
|
||||
train_examples = []
|
||||
|
@ -119,10 +120,11 @@ def test_no_data():
|
|||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
|
||||
def test_incomplete_data():
|
||||
@pytest.mark.parametrize("top_k", (1, 5, 30))
|
||||
def test_incomplete_data(top_k):
|
||||
# Test that the lemmatizer works with incomplete information
|
||||
nlp = English()
|
||||
lemmatizer = nlp.add_pipe("trainable_lemmatizer")
|
||||
lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
|
||||
lemmatizer.min_tree_freq = 1
|
||||
train_examples = []
|
||||
for t in PARTIAL_DATA:
|
||||
|
@ -139,10 +141,25 @@ def test_incomplete_data():
|
|||
assert doc[1].lemma_ == "like"
|
||||
assert doc[2].lemma_ == "blue"
|
||||
|
||||
# Check that incomplete annotations are ignored.
|
||||
scores, _ = lemmatizer.model([eg.predicted for eg in train_examples], is_train=True)
|
||||
_, dX = lemmatizer.get_loss(train_examples, scores)
|
||||
xp = lemmatizer.model.ops.xp
|
||||
|
||||
def test_overfitting_IO():
|
||||
# Missing annotations.
|
||||
assert xp.count_nonzero(dX[0][0]) == 0
|
||||
assert xp.count_nonzero(dX[0][3]) == 0
|
||||
assert xp.count_nonzero(dX[1][0]) == 0
|
||||
assert xp.count_nonzero(dX[1][3]) == 0
|
||||
|
||||
# Misaligned annotations.
|
||||
assert xp.count_nonzero(dX[1][1]) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_k", (1, 5, 30))
|
||||
def test_overfitting_IO(top_k):
|
||||
nlp = English()
|
||||
lemmatizer = nlp.add_pipe("trainable_lemmatizer")
|
||||
lemmatizer = nlp.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
|
||||
lemmatizer.min_tree_freq = 1
|
||||
train_examples = []
|
||||
for t in TRAIN_DATA:
|
||||
|
@ -175,7 +192,7 @@ def test_overfitting_IO():
|
|||
# Check model after a {to,from}_bytes roundtrip
|
||||
nlp_bytes = nlp.to_bytes()
|
||||
nlp3 = English()
|
||||
nlp3.add_pipe("trainable_lemmatizer")
|
||||
nlp3.add_pipe("trainable_lemmatizer", config={"top_k": top_k})
|
||||
nlp3.from_bytes(nlp_bytes)
|
||||
doc3 = nlp3(test_text)
|
||||
assert doc3[0].lemma_ == "she"
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Callable, Iterable, Dict, Any, Iterator
|
||||
from typing import Callable, Iterable, Dict, Any, Iterator, Tuple
|
||||
|
||||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
|
||||
from spacy import registry, util
|
||||
from spacy import registry, util, Language
|
||||
from spacy.attrs import ENT_KB_ID
|
||||
from spacy.compat import pickle
|
||||
from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase
|
||||
|
@ -108,18 +108,23 @@ def test_issue7065():
|
|||
|
||||
|
||||
@pytest.mark.issue(7065)
|
||||
def test_issue7065_b():
|
||||
@pytest.mark.parametrize("entity_in_first_sentence", [True, False])
|
||||
def test_sentence_crossing_ents(entity_in_first_sentence: bool):
|
||||
"""Tests if NEL crashes if entities cross sentence boundaries and the first associated sentence doesn't have an
|
||||
entity.
|
||||
entity_in_prior_sentence (bool): Whether to include an entity in the first sentence associated with the
|
||||
sentence-crossing entity.
|
||||
"""
|
||||
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
|
||||
nlp = English()
|
||||
vector_length = 3
|
||||
nlp.add_pipe("sentencizer")
|
||||
text = "Mahler 's Symphony No. 8 was beautiful."
|
||||
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
|
||||
links = {
|
||||
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
|
||||
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
|
||||
}
|
||||
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
|
||||
entities = [(10, 24, "WORK")]
|
||||
links = {(10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
|
||||
if entity_in_first_sentence:
|
||||
entities.append((0, 6, "PERSON"))
|
||||
links[(0, 6)] = {"Q7304": 1.0, "Q270853": 0.0}
|
||||
sent_starts = [1, -1, 0, 0, 0, 1, 0, 0, 0]
|
||||
doc = nlp(text)
|
||||
example = Example.from_dict(
|
||||
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
|
||||
|
@ -145,31 +150,14 @@ def test_issue7065_b():
|
|||
|
||||
# Create the Entity Linker component and add it to the pipeline
|
||||
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
||||
entity_linker.set_kb(create_kb)
|
||||
entity_linker.set_kb(create_kb) # type: ignore
|
||||
# train the NEL pipe
|
||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||
for i in range(2):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
nlp.update(train_examples, sgd=optimizer)
|
||||
|
||||
# Add a custom rule-based component to mimick NER
|
||||
patterns = [
|
||||
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
|
||||
{
|
||||
"label": "WORK",
|
||||
"pattern": [
|
||||
{"LOWER": "symphony"},
|
||||
{"LOWER": "no"},
|
||||
{"LOWER": "."},
|
||||
{"LOWER": "8"},
|
||||
],
|
||||
},
|
||||
]
|
||||
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
|
||||
ruler.add_patterns(patterns)
|
||||
# test the trained model - this should not throw E148
|
||||
doc = nlp(text)
|
||||
assert doc
|
||||
# This shouldn't crash.
|
||||
entity_linker.predict([example.reference]) # type: ignore
|
||||
|
||||
|
||||
def test_no_entities():
|
||||
|
@ -353,6 +341,9 @@ def test_kb_default(nlp):
|
|||
"""Test that the default (empty) KB is loaded upon construction"""
|
||||
entity_linker = nlp.add_pipe("entity_linker", config={})
|
||||
assert len(entity_linker.kb) == 0
|
||||
with pytest.raises(ValueError, match="E139"):
|
||||
# this raises an error because the KB is empty
|
||||
entity_linker.validate_kb()
|
||||
assert entity_linker.kb.get_size_entities() == 0
|
||||
assert entity_linker.kb.get_size_aliases() == 0
|
||||
# 64 is the default value from pipeline.entity_linker
|
||||
|
|
|
@ -382,6 +382,43 @@ def test_entity_ruler_overlapping_spans(nlp, entity_ruler_factory):
|
|||
assert doc.ents[0].label_ == "FOOBAR"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_fuzzy_pipe(nlp, entity_ruler_factory):
|
||||
ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
|
||||
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
|
||||
ruler.add_patterns(patterns)
|
||||
doc = nlp("helloo")
|
||||
assert len(doc.ents) == 1
|
||||
assert doc.ents[0].label_ == "HELLO"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_fuzzy(nlp, entity_ruler_factory):
|
||||
ruler = nlp.add_pipe(entity_ruler_factory, name="entity_ruler")
|
||||
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
|
||||
ruler.add_patterns(patterns)
|
||||
doc = nlp("helloo")
|
||||
assert len(doc.ents) == 1
|
||||
assert doc.ents[0].label_ == "HELLO"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_fuzzy_disabled(nlp, entity_ruler_factory):
|
||||
@registry.misc("test_fuzzy_compare_disabled")
|
||||
def make_test_fuzzy_compare_disabled():
|
||||
return lambda x, y, z: False
|
||||
|
||||
ruler = nlp.add_pipe(
|
||||
entity_ruler_factory,
|
||||
name="entity_ruler",
|
||||
config={"matcher_fuzzy_compare": {"@misc": "test_fuzzy_compare_disabled"}},
|
||||
)
|
||||
patterns = [{"label": "HELLO", "pattern": [{"LOWER": {"FUZZY": "hello"}}]}]
|
||||
ruler.add_patterns(patterns)
|
||||
doc = nlp("helloo")
|
||||
assert len(doc.ents) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_process", [1, 2])
|
||||
@pytest.mark.parametrize("entity_ruler_factory", ENTITY_RULERS)
|
||||
def test_entity_ruler_multiprocessing(nlp, n_process, entity_ruler_factory):
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
from numpy.testing import assert_equal, assert_almost_equal
|
||||
|
||||
from thinc.api import get_current_ops
|
||||
|
||||
from spacy import util
|
||||
from spacy.training import Example
|
||||
|
@ -19,6 +21,8 @@ def test_label_types():
|
|||
morphologizer.add_label(9)
|
||||
|
||||
|
||||
TAGS = ["Feat=N", "Feat=V", "Feat=J"]
|
||||
|
||||
TRAIN_DATA = [
|
||||
(
|
||||
"I like green eggs",
|
||||
|
@ -32,6 +36,30 @@ TRAIN_DATA = [
|
|||
]
|
||||
|
||||
|
||||
def test_label_smoothing():
|
||||
nlp = Language()
|
||||
morph_no_ls = nlp.add_pipe("morphologizer", "no_label_smoothing")
|
||||
morph_ls = nlp.add_pipe(
|
||||
"morphologizer", "label_smoothing", config=dict(label_smoothing=0.05)
|
||||
)
|
||||
train_examples = []
|
||||
losses = {}
|
||||
for tag in TAGS:
|
||||
morph_no_ls.add_label(tag)
|
||||
morph_ls.add_label(tag)
|
||||
for t in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
tag_scores, bp_tag_scores = morph_ls.model.begin_update(
|
||||
[eg.predicted for eg in train_examples]
|
||||
)
|
||||
ops = get_current_ops()
|
||||
no_ls_grads = ops.to_numpy(morph_no_ls.get_loss(train_examples, tag_scores)[1][0])
|
||||
ls_grads = ops.to_numpy(morph_ls.get_loss(train_examples, tag_scores)[1][0])
|
||||
assert_almost_equal(ls_grads / no_ls_grads, 0.94285715)
|
||||
|
||||
|
||||
def test_no_label():
|
||||
nlp = Language()
|
||||
nlp.add_pipe("morphologizer")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_array_equal, assert_almost_equal
|
||||
from thinc.api import get_current_ops, Ragged
|
||||
from thinc.api import get_current_ops, NumpyOps, Ragged
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
|
@ -15,6 +15,8 @@ OPS = get_current_ops()
|
|||
|
||||
SPAN_KEY = "labeled_spans"
|
||||
|
||||
SPANCAT_COMPONENTS = ["spancat", "spancat_singlelabel"]
|
||||
|
||||
TRAIN_DATA = [
|
||||
("Who is Shaka Khan?", {"spans": {SPAN_KEY: [(7, 17, "PERSON")]}}),
|
||||
(
|
||||
|
@ -41,38 +43,42 @@ def make_examples(nlp, data=TRAIN_DATA):
|
|||
return train_examples
|
||||
|
||||
|
||||
def test_no_label():
|
||||
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||
def test_no_label(name):
|
||||
nlp = Language()
|
||||
nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||
with pytest.raises(ValueError):
|
||||
nlp.initialize()
|
||||
|
||||
|
||||
def test_no_resize():
|
||||
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||
def test_no_resize(name):
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||
spancat.add_label("Thing")
|
||||
spancat.add_label("Phrase")
|
||||
assert spancat.labels == ("Thing", "Phrase")
|
||||
nlp.initialize()
|
||||
assert spancat.model.get_dim("nO") == 2
|
||||
assert spancat.model.get_dim("nO") == spancat._n_labels
|
||||
# this throws an error because the spancat can't be resized after initialization
|
||||
with pytest.raises(ValueError):
|
||||
spancat.add_label("Stuff")
|
||||
|
||||
|
||||
def test_implicit_labels():
|
||||
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||
def test_implicit_labels(name):
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||
assert len(spancat.labels) == 0
|
||||
train_examples = make_examples(nlp)
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
assert spancat.labels == ("PERSON", "LOC")
|
||||
|
||||
|
||||
def test_explicit_labels():
|
||||
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||
def test_explicit_labels(name):
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
|
||||
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||
assert len(spancat.labels) == 0
|
||||
spancat.add_label("PERSON")
|
||||
spancat.add_label("LOC")
|
||||
|
@ -102,13 +108,13 @@ def test_doc_gc():
|
|||
# XXX This fails with length 0 sometimes
|
||||
assert len(spangroup) > 0
|
||||
with pytest.raises(RuntimeError):
|
||||
span = spangroup[0]
|
||||
spangroup[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_positive,nr_results", [(None, 4), (1, 2), (2, 3), (3, 4), (4, 4)]
|
||||
)
|
||||
def test_make_spangroup(max_positive, nr_results):
|
||||
def test_make_spangroup_multilabel(max_positive, nr_results):
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe(
|
||||
|
@ -120,10 +126,12 @@ def test_make_spangroup(max_positive, nr_results):
|
|||
indices = ngram_suggester([doc])[0].dataXd
|
||||
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||
for label in labels:
|
||||
spancat.add_label(label)
|
||||
scores = numpy.asarray(
|
||||
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||
)
|
||||
spangroup = spancat._make_span_group(doc, indices, scores, labels)
|
||||
spangroup = spancat._make_span_group_multilabel(doc, indices, scores)
|
||||
assert len(spangroup) == nr_results
|
||||
|
||||
# first span is always the second token "London"
|
||||
|
@ -154,6 +162,130 @@ def test_make_spangroup(max_positive, nr_results):
|
|||
assert_almost_equal(0.9, spangroup.attrs["scores"][-1], 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"threshold,allow_overlap,nr_results",
|
||||
[(0.05, True, 3), (0.05, False, 1), (0.5, True, 2), (0.5, False, 1)],
|
||||
)
|
||||
def test_make_spangroup_singlelabel(threshold, allow_overlap, nr_results):
|
||||
fix_random_seed(0)
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe(
|
||||
"spancat",
|
||||
config={
|
||||
"spans_key": SPAN_KEY,
|
||||
"threshold": threshold,
|
||||
"max_positive": 1,
|
||||
},
|
||||
)
|
||||
doc = nlp.make_doc("Greater London")
|
||||
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
|
||||
indices = ngram_suggester([doc])[0].dataXd
|
||||
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||
for label in labels:
|
||||
spancat.add_label(label)
|
||||
scores = numpy.asarray(
|
||||
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||
)
|
||||
spangroup = spancat._make_span_group_singlelabel(
|
||||
doc, indices, scores, allow_overlap
|
||||
)
|
||||
if threshold > 0.4:
|
||||
if allow_overlap:
|
||||
assert spangroup[0].text == "London"
|
||||
assert spangroup[0].label_ == "City"
|
||||
assert_almost_equal(0.6, spangroup.attrs["scores"][0], 5)
|
||||
assert spangroup[1].text == "Greater London"
|
||||
assert spangroup[1].label_ == "GreatCity"
|
||||
assert spangroup.attrs["scores"][1] == 0.9
|
||||
assert_almost_equal(0.9, spangroup.attrs["scores"][1], 5)
|
||||
else:
|
||||
assert spangroup[0].text == "Greater London"
|
||||
assert spangroup[0].label_ == "GreatCity"
|
||||
assert spangroup.attrs["scores"][0] == 0.9
|
||||
else:
|
||||
if allow_overlap:
|
||||
assert spangroup[0].text == "Greater"
|
||||
assert spangroup[0].label_ == "City"
|
||||
assert spangroup[1].text == "London"
|
||||
assert spangroup[1].label_ == "City"
|
||||
assert spangroup[2].text == "Greater London"
|
||||
assert spangroup[2].label_ == "GreatCity"
|
||||
else:
|
||||
assert spangroup[0].text == "Greater London"
|
||||
|
||||
|
||||
def test_make_spangroup_negative_label():
    fix_random_seed(0)
    nlp_single = Language()
    nlp_multi = Language()
    spancat_single = nlp_single.add_pipe(
        "spancat",
        config={
            "spans_key": SPAN_KEY,
            "threshold": 0.1,
            "max_positive": 1,
        },
    )
    spancat_multi = nlp_multi.add_pipe(
        "spancat",
        config={
            "spans_key": SPAN_KEY,
            "threshold": 0.1,
            "max_positive": 2,
        },
    )
    spancat_single.add_negative_label = True
    spancat_multi.add_negative_label = True
    doc = nlp_single.make_doc("Greater London")
    labels = ["Thing", "City", "Person", "GreatCity"]
    for label in labels:
        spancat_multi.add_label(label)
        spancat_single.add_label(label)
    ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
    indices = ngram_suggester([doc])[0].dataXd
    assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
    scores = numpy.asarray(
        [
            [0.2, 0.4, 0.3, 0.1, 0.1],
            [0.1, 0.6, 0.2, 0.4, 0.9],
            [0.8, 0.7, 0.3, 0.9, 0.1],
        ],
        dtype="f",
    )
    spangroup_multi = spancat_multi._make_span_group_multilabel(doc, indices, scores)
    spangroup_single = spancat_single._make_span_group_singlelabel(doc, indices, scores)
    assert len(spangroup_single) == 2
    assert spangroup_single[0].text == "Greater"
    assert spangroup_single[0].label_ == "City"
    assert_almost_equal(0.4, spangroup_single.attrs["scores"][0], 5)
    assert spangroup_single[1].text == "Greater London"
    assert spangroup_single[1].label_ == "GreatCity"
    assert spangroup_single.attrs["scores"][1] == 0.9
    assert_almost_equal(0.9, spangroup_single.attrs["scores"][1], 5)

    assert len(spangroup_multi) == 6
    assert spangroup_multi[0].text == "Greater"
    assert spangroup_multi[0].label_ == "City"
    assert_almost_equal(0.4, spangroup_multi.attrs["scores"][0], 5)
    assert spangroup_multi[1].text == "Greater"
    assert spangroup_multi[1].label_ == "Person"
    assert_almost_equal(0.3, spangroup_multi.attrs["scores"][1], 5)
    assert spangroup_multi[2].text == "London"
    assert spangroup_multi[2].label_ == "City"
    assert_almost_equal(0.6, spangroup_multi.attrs["scores"][2], 5)
    assert spangroup_multi[3].text == "London"
    assert spangroup_multi[3].label_ == "GreatCity"
    assert_almost_equal(0.4, spangroup_multi.attrs["scores"][3], 5)
    assert spangroup_multi[4].text == "Greater London"
    assert spangroup_multi[4].label_ == "Thing"
    assert spangroup_multi[4].text == "Greater London"
    assert_almost_equal(0.8, spangroup_multi.attrs["scores"][4], 5)
    assert spangroup_multi[5].text == "Greater London"
    assert spangroup_multi[5].label_ == "GreatCity"
    assert_almost_equal(0.9, spangroup_multi.attrs["scores"][5], 5)


def test_ngram_suggester(en_tokenizer):
    # test different n-gram lengths
    for size in [1, 2, 3]:
@@ -371,9 +503,9 @@ def test_overfitting_IO_overlapping():
    assert set([span.label_ for span in spans2]) == {"LOC", "DOUBLE_LOC"}


def test_zero_suggestions():
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_zero_suggestions(name):
    # Test with a suggester that can return 0 suggestions

    @registry.misc("test_mixed_zero_suggester")
    def make_mixed_zero_suggester():
        def mixed_zero_suggester(docs, *, ops=None):
@@ -400,7 +532,7 @@ def test_zero_suggestions():
    fix_random_seed(0)
    nlp = English()
    spancat = nlp.add_pipe(
        "spancat",
        name,
        config={
            "suggester": {"@misc": "test_mixed_zero_suggester"},
            "spans_key": SPAN_KEY,
@@ -408,7 +540,7 @@ def test_zero_suggestions():
    )
    train_examples = make_examples(nlp)
    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    assert spancat.model.get_dim("nO") == 2
    assert spancat.model.get_dim("nO") == spancat._n_labels
    assert set(spancat.labels) == {"LOC", "PERSON"}

    nlp.update(train_examples, sgd=optimizer)
@@ -424,9 +556,10 @@ def test_zero_suggestions():
    list(nlp.pipe(["", "one", "three three three"]))


def test_set_candidates():
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
def test_set_candidates(name):
    nlp = Language()
    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
    spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
    train_examples = make_examples(nlp)
    nlp.initialize(get_examples=lambda: train_examples)
    texts = [
@@ -444,3 +577,21 @@ def test_set_candidates():
    assert len(docs[0].spans["candidates"]) == 9
    assert docs[0].spans["candidates"][0].text == "Just"
    assert docs[0].spans["candidates"][4].text == "Just a"


@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
@pytest.mark.parametrize("n_process", [1, 2])
def test_spancat_multiprocessing(name, n_process):
    if isinstance(get_current_ops, NumpyOps) or n_process < 2:
        nlp = Language()
        spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
        train_examples = make_examples(nlp)
        nlp.initialize(get_examples=lambda: train_examples)
        texts = [
            "Just a sentence.",
            "I like London and Berlin",
            "I like Berlin",
            "I eat ham.",
        ]
        docs = list(nlp.pipe(texts, n_process=n_process))
        assert len(docs) == len(texts)

@@ -1,12 +1,12 @@
import pytest
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_almost_equal
from spacy.attrs import TAG

from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.language import Language
from thinc.api import compounding
from thinc.api import compounding, get_current_ops

from ..util import make_tempdir

@@ -67,6 +67,30 @@ PARTIAL_DATA = [
]


def test_label_smoothing():
    nlp = Language()
    tagger_no_ls = nlp.add_pipe("tagger", "no_label_smoothing")
    tagger_ls = nlp.add_pipe(
        "tagger", "label_smoothing", config=dict(label_smoothing=0.05)
    )
    train_examples = []
    losses = {}
    for tag in TAGS:
        tagger_no_ls.add_label(tag)
        tagger_ls.add_label(tag)
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))

    nlp.initialize(get_examples=lambda: train_examples)
    tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
        [eg.predicted for eg in train_examples]
    )
    ops = get_current_ops()
    no_ls_grads = ops.to_numpy(tagger_no_ls.get_loss(train_examples, tag_scores)[1][0])
    ls_grads = ops.to_numpy(tagger_ls.get_loss(train_examples, tag_scores)[1][0])
    assert_almost_equal(ls_grads / no_ls_grads, 0.925)


def test_no_label():
    nlp = Language()
    nlp.add_pipe("tagger")

@@ -895,3 +895,26 @@ def test_textcat_multi_threshold():

    scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
    assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0


@pytest.mark.parametrize(
    "component_name,scorer",
    [
        ("textcat", "spacy.textcat_scorer.v1"),
        ("textcat_multilabel", "spacy.textcat_multilabel_scorer.v1"),
    ],
)
def test_textcat_legacy_scorers(component_name, scorer):
    """Check that legacy scorers are registered and produce the expected score
    keys."""
    nlp = English()
    nlp.add_pipe(component_name, config={"scorer": {"@scorers": scorer}})

    train_examples = []
    for text, annotations in TRAIN_DATA_SINGLE_LABEL:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
    nlp.initialize(get_examples=lambda: train_examples)

    # score the model (it's not actually trained but that doesn't matter)
    scores = nlp.evaluate(train_examples)
    assert 0 <= scores["cats_score"] <= 1

@@ -213,6 +213,13 @@ def test_serialize_doc_exclude(en_vocab):

def test_serialize_doc_span_groups(en_vocab):
    doc = Doc(en_vocab, words=["hello", "world", "!"])
    doc.spans["content"] = [doc[0:2]]
    span = doc[0:2]
    span.label_ = "test_serialize_doc_span_groups_label"
    span.id_ = "test_serialize_doc_span_groups_id"
    span.kb_id_ = "test_serialize_doc_span_groups_kb_id"
    doc.spans["content"] = [span]
    new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
    assert len(new_doc.spans["content"]) == 1
    assert new_doc.spans["content"][0].label_ == "test_serialize_doc_span_groups_label"
    assert new_doc.spans["content"][0].id_ == "test_serialize_doc_span_groups_id"
    assert new_doc.spans["content"][0].kb_id_ == "test_serialize_doc_span_groups_kb_id"

@@ -49,7 +49,11 @@ def test_serialize_doc_bin():
    nlp = English()
    for doc in nlp.pipe(texts):
        doc.cats = cats
        doc.spans["start"] = [doc[0:2]]
        span = doc[0:2]
        span.label_ = "UNUSUAL_SPAN_LABEL"
        span.id_ = "UNUSUAL_SPAN_ID"
        span.kb_id_ = "UNUSUAL_SPAN_KB_ID"
        doc.spans["start"] = [span]
        doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
        doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
        doc_bin.add(doc)
@@ -63,6 +67,9 @@ def test_serialize_doc_bin():
        assert doc.text == texts[i]
        assert doc.cats == cats
        assert len(doc.spans) == 1
        assert doc.spans["start"][0].label_ == "UNUSUAL_SPAN_LABEL"
        assert doc.spans["start"][0].id_ == "UNUSUAL_SPAN_ID"
        assert doc.spans["start"][0].kb_id_ == "UNUSUAL_SPAN_KB_ID"
        assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
        assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"

@ -1,7 +1,10 @@
|
|||
from typing import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Any, Dict
|
||||
|
||||
from spacy import util
|
||||
from spacy.util import ensure_path, registry, load_model_from_config
|
||||
import srsly
|
||||
|
||||
from spacy import util, Errors
|
||||
from spacy.util import ensure_path, registry, load_model_from_config, SimpleFrozenList
|
||||
from spacy.kb.kb_in_memory import InMemoryLookupKB
|
||||
from spacy.vocab import Vocab
|
||||
from thinc.api import Config
|
||||
|
@ -91,7 +94,10 @@ def test_serialize_subclassed_kb():
|
|||
|
||||
[components.entity_linker]
|
||||
factory = "entity_linker"
|
||||
|
||||
|
||||
[components.entity_linker.generate_empty_kb]
|
||||
@misc = "kb_test.CustomEmptyKB.v1"
|
||||
|
||||
[initialize]
|
||||
|
||||
[initialize.components]
|
||||
|
@ -99,7 +105,7 @@ def test_serialize_subclassed_kb():
|
|||
[initialize.components.entity_linker]
|
||||
|
||||
[initialize.components.entity_linker.kb_loader]
|
||||
@misc = "spacy.CustomKB.v1"
|
||||
@misc = "kb_test.CustomKB.v1"
|
||||
entity_vector_length = 342
|
||||
custom_field = 666
|
||||
"""
|
||||
|
@ -109,10 +115,57 @@ def test_serialize_subclassed_kb():
|
|||
super().__init__(vocab, entity_vector_length)
|
||||
self.custom_field = custom_field
|
||||
|
||||
@registry.misc("spacy.CustomKB.v1")
|
||||
def to_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||
"""We overwrite InMemoryLookupKB.to_disk() to ensure that self.custom_field is stored as well."""
|
||||
path = ensure_path(path)
|
||||
if not path.exists():
|
||||
path.mkdir(parents=True)
|
||||
if not path.is_dir():
|
||||
raise ValueError(Errors.E928.format(loc=path))
|
||||
|
||||
def serialize_custom_fields(file_path: Path) -> None:
|
||||
srsly.write_json(file_path, {"custom_field": self.custom_field})
|
||||
|
||||
serialize = {
|
||||
"contents": lambda p: self.write_contents(p),
|
||||
"strings.json": lambda p: self.vocab.strings.to_disk(p),
|
||||
"custom_fields": lambda p: serialize_custom_fields(p),
|
||||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude: Iterable[str] = SimpleFrozenList()):
|
||||
"""We overwrite InMemoryLookupKB.from_disk() to ensure that self.custom_field is loaded as well."""
|
||||
path = ensure_path(path)
|
||||
if not path.exists():
|
||||
raise ValueError(Errors.E929.format(loc=path))
|
||||
if not path.is_dir():
|
||||
raise ValueError(Errors.E928.format(loc=path))
|
||||
|
||||
def deserialize_custom_fields(file_path: Path) -> None:
|
||||
self.custom_field = srsly.read_json(file_path)["custom_field"]
|
||||
|
||||
deserialize: Dict[str, Callable[[Any], Any]] = {
|
||||
"contents": lambda p: self.read_contents(p),
|
||||
"strings.json": lambda p: self.vocab.strings.from_disk(p),
|
||||
"custom_fields": lambda p: deserialize_custom_fields(p),
|
||||
}
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
|
||||
@registry.misc("kb_test.CustomEmptyKB.v1")
|
||||
def empty_custom_kb() -> Callable[[Vocab, int], SubInMemoryLookupKB]:
|
||||
def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
|
||||
return SubInMemoryLookupKB(
|
||||
vocab=vocab,
|
||||
entity_vector_length=entity_vector_length,
|
||||
custom_field=0,
|
||||
)
|
||||
|
||||
return empty_kb_factory
|
||||
|
||||
@registry.misc("kb_test.CustomKB.v1")
|
||||
def custom_kb(
|
||||
entity_vector_length: int, custom_field: int
|
||||
) -> Callable[[Vocab], InMemoryLookupKB]:
|
||||
) -> Callable[[Vocab], SubInMemoryLookupKB]:
|
||||
def custom_kb_factory(vocab):
|
||||
kb = SubInMemoryLookupKB(
|
||||
vocab=vocab,
|
||||
|
@ -139,6 +192,6 @@ def test_serialize_subclassed_kb():
|
|||
nlp2 = util.load_model_from_path(tmp_dir)
|
||||
entity_linker2 = nlp2.get_pipe("entity_linker")
|
||||
# After IO, the KB is the standard one
|
||||
assert type(entity_linker2.kb) == InMemoryLookupKB
|
||||
assert type(entity_linker2.kb) == SubInMemoryLookupKB
|
||||
assert entity_linker2.kb.entity_vector_length == 342
|
||||
assert not hasattr(entity_linker2.kb, "custom_field")
|
||||
assert entity_linker2.kb.custom_field == 666
|
||||
|
|
|
@ -2,9 +2,10 @@ import os
|
|||
import math
|
||||
from collections import Counter
|
||||
from typing import Tuple, List, Dict, Any
|
||||
import pkg_resources
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import spacy
|
||||
import numpy
|
||||
import pytest
|
||||
import srsly
|
||||
|
@ -14,7 +15,7 @@ from thinc.api import Config, ConfigValidationError
|
|||
|
||||
from spacy import about
|
||||
from spacy.cli import info
|
||||
from spacy.cli._util import is_subpath_of, load_project_config
|
||||
from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
|
||||
from spacy.cli._util import parse_config_overrides, string_to_list
|
||||
from spacy.cli._util import substitute_project_variables
|
||||
from spacy.cli._util import validate_project_commands
|
||||
|
@ -27,11 +28,13 @@ from spacy.cli.debug_data import _print_span_characteristics
|
|||
from spacy.cli.debug_data import _get_spans_length_freq_dist
|
||||
from spacy.cli.download import get_compatibility, get_version
|
||||
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
||||
from spacy.cli.init_pipeline import _init_labels
|
||||
from spacy.cli.package import get_third_party_dependencies
|
||||
from spacy.cli.package import _is_permitted_package_name
|
||||
from spacy.cli.project.remote_storage import RemoteStorage
|
||||
from spacy.cli.project.run import _check_requirements
|
||||
from spacy.cli.validate import get_model_pkgs
|
||||
from spacy.cli.apply import apply
|
||||
from spacy.cli.find_threshold import find_threshold
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.nl import Dutch
|
||||
|
@ -44,7 +47,6 @@ from spacy.training.converters import conll_ner_to_docs, conllu_to_docs
|
|||
from spacy.training.converters import iob_to_docs
|
||||
from spacy.util import ENV_VARS, get_minor_version, load_model_from_config, load_config
|
||||
|
||||
from ..cli.init_pipeline import _init_labels
|
||||
from .util import make_tempdir
|
||||
|
||||
|
||||
|
@ -550,7 +552,14 @@ def test_parse_cli_overrides():
|
|||
|
||||
@pytest.mark.parametrize("lang", ["en", "nl"])
|
||||
@pytest.mark.parametrize(
|
||||
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]
|
||||
"pipeline",
|
||||
[
|
||||
["tagger", "parser", "ner"],
|
||||
[],
|
||||
["ner", "textcat", "sentencizer"],
|
||||
["morphologizer", "spancat", "entity_linker"],
|
||||
["spancat_singlelabel", "textcat_multilabel"],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
|
||||
@pytest.mark.parametrize("pretraining", [True, False])
|
||||
|
@ -615,7 +624,6 @@ def test_string_to_list_intify(value):
|
|||
assert string_to_list(value, intify=True) == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily skip for dev version")
|
||||
def test_download_compatibility():
|
||||
spec = SpecifierSet("==" + about.__version__)
|
||||
spec.prereleases = False
|
||||
|
@ -626,7 +634,6 @@ def test_download_compatibility():
|
|||
assert get_minor_version(about.__version__) == get_minor_version(version)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily skip for dev version")
|
||||
def test_validate_compatibility_table():
|
||||
spec = SpecifierSet("==" + about.__version__)
|
||||
spec.prereleases = False
|
||||
|
@ -885,6 +892,82 @@ def test_span_length_freq_dist_output_must_be_correct():
|
|||
assert list(span_freqs.keys()) == [3, 1, 4, 5, 2]
|
||||
|
||||
|
||||
def test_applycli_empty_dir():
|
||||
with make_tempdir() as data_path:
|
||||
output = data_path / "test.spacy"
|
||||
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||
|
||||
|
||||
def test_applycli_docbin():
|
||||
with make_tempdir() as data_path:
|
||||
output = data_path / "testout.spacy"
|
||||
nlp = spacy.blank("en")
|
||||
doc = nlp("testing apply cli.")
|
||||
# test empty DocBin case
|
||||
docbin = DocBin()
|
||||
docbin.to_disk(data_path / "testin.spacy")
|
||||
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||
docbin.add(doc)
|
||||
docbin.to_disk(data_path / "testin.spacy")
|
||||
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||
|
||||
|
||||
def test_applycli_jsonl():
|
||||
with make_tempdir() as data_path:
|
||||
output = data_path / "testout.spacy"
|
||||
data = [{"field": "Testing apply cli.", "key": 234}]
|
||||
data2 = [{"field": "234"}]
|
||||
srsly.write_jsonl(data_path / "test.jsonl", data)
|
||||
apply(data_path, output, "blank:en", "field", 1, 1)
|
||||
srsly.write_jsonl(data_path / "test2.jsonl", data2)
|
||||
apply(data_path, output, "blank:en", "field", 1, 1)
|
||||
|
||||
|
||||
def test_applycli_txt():
|
||||
with make_tempdir() as data_path:
|
||||
output = data_path / "testout.spacy"
|
||||
with open(data_path / "test.foo", "w") as ftest:
|
||||
ftest.write("Testing apply cli.")
|
||||
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||
|
||||
|
||||
def test_applycli_mixed():
|
||||
with make_tempdir() as data_path:
|
||||
output = data_path / "testout.spacy"
|
||||
text = "Testing apply cli"
|
||||
nlp = spacy.blank("en")
|
||||
doc = nlp(text)
|
||||
jsonl_data = [{"text": text}]
|
||||
srsly.write_jsonl(data_path / "test.jsonl", jsonl_data)
|
||||
docbin = DocBin()
|
||||
docbin.add(doc)
|
||||
docbin.to_disk(data_path / "testin.spacy")
|
||||
with open(data_path / "test.txt", "w") as ftest:
|
||||
ftest.write(text)
|
||||
apply(data_path, output, "blank:en", "text", 1, 1)
|
||||
# Check whether it worked
|
||||
result = list(DocBin().from_disk(output).get_docs(nlp.vocab))
|
||||
assert len(result) == 3
|
||||
for doc in result:
|
||||
assert doc.text == text
|
||||
|
||||
|
||||
def test_applycli_user_data():
|
||||
Doc.set_extension("ext", default=0)
|
||||
val = ("ext", 0)
|
||||
with make_tempdir() as data_path:
|
||||
output = data_path / "testout.spacy"
|
||||
nlp = spacy.blank("en")
|
||||
doc = nlp("testing apply cli.")
|
||||
doc._.ext = val
|
||||
docbin = DocBin(store_user_data=True)
|
||||
docbin.add(doc)
|
||||
docbin.to_disk(data_path / "testin.spacy")
|
||||
apply(data_path, output, "blank:en", "", 1, 1)
|
||||
result = list(DocBin().from_disk(output).get_docs(nlp.vocab))
|
||||
assert result[0]._.ext == val
|
||||
|
||||
|
||||
def test_local_remote_storage():
|
||||
with make_tempdir() as d:
|
||||
filename = "a.txt"
|
||||
|
@ -940,8 +1023,6 @@ def test_local_remote_storage_pull_missing():
|
|||
|
||||
|
||||
def test_cli_find_threshold(capsys):
|
||||
thresholds = numpy.linspace(0, 1, 10)
|
||||
|
||||
def make_examples(nlp: Language) -> List[Example]:
|
||||
docs: List[Example] = []
|
||||
|
||||
|
@ -997,7 +1078,7 @@ def test_cli_find_threshold(capsys):
|
|||
)
|
||||
with make_tempdir() as nlp_dir:
|
||||
nlp.to_disk(nlp_dir)
|
||||
res = find_threshold(
|
||||
best_threshold, best_score, res = find_threshold(
|
||||
model=nlp_dir,
|
||||
data_path=docs_dir / "docs.spacy",
|
||||
pipe_name="tc_multi",
|
||||
|
@ -1005,16 +1086,14 @@ def test_cli_find_threshold(capsys):
|
|||
scores_key="cats_macro_f",
|
||||
silent=True,
|
||||
)
|
||||
assert res[0] != thresholds[0]
|
||||
assert thresholds[0] < res[0] < thresholds[9]
|
||||
assert res[1] == 1.0
|
||||
assert res[2][1.0] == 0.0
|
||||
assert best_score == max(res.values())
|
||||
assert res[1.0] == 0.0
|
||||
|
||||
# Test with spancat.
|
||||
nlp, _ = init_nlp((("spancat", {}),))
|
||||
with make_tempdir() as nlp_dir:
|
||||
nlp.to_disk(nlp_dir)
|
||||
res = find_threshold(
|
||||
best_threshold, best_score, res = find_threshold(
|
||||
model=nlp_dir,
|
||||
data_path=docs_dir / "docs.spacy",
|
||||
pipe_name="spancat",
|
||||
|
@ -1022,10 +1101,8 @@ def test_cli_find_threshold(capsys):
|
|||
scores_key="spans_sc_f",
|
||||
silent=True,
|
||||
)
|
||||
assert res[0] != thresholds[0]
|
||||
assert thresholds[0] < res[0] < thresholds[8]
|
||||
assert res[1] >= 0.6
|
||||
assert res[2][1.0] == 0.0
|
||||
assert best_score == max(res.values())
|
||||
assert res[1.0] == 0.0
|
||||
|
||||
# Having multiple textcat_multilabel components should work, since the name has to be specified.
|
||||
nlp, _ = init_nlp((("textcat_multilabel", {}),))
|
||||
|
@ -1055,6 +1132,7 @@ def test_cli_find_threshold(capsys):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
||||
@pytest.mark.parametrize(
|
||||
"reqs,output",
|
||||
[
|
||||
|
@ -1087,6 +1165,8 @@ def test_cli_find_threshold(capsys):
|
|||
],
|
||||
)
|
||||
def test_project_check_requirements(reqs, output):
|
||||
import pkg_resources
|
||||
|
||||
# excessive guard against unlikely package name
|
||||
try:
|
||||
pkg_resources.require("spacyunknowndoesnotexist12345")
|
||||
|
@ -1107,3 +1187,92 @@ def test_upload_download_local_file():
|
|||
download_file(remote_file, local_file)
|
||||
with local_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
|
||||
def test_walk_directory():
|
||||
with make_tempdir() as d:
|
||||
files = [
|
||||
"data1.iob",
|
||||
"data2.iob",
|
||||
"data3.json",
|
||||
"data4.conll",
|
||||
"data5.conll",
|
||||
"data6.conll",
|
||||
"data7.txt",
|
||||
]
|
||||
|
||||
for f in files:
|
||||
Path(d / f).touch()
|
||||
|
||||
assert (len(walk_directory(d))) == 7
|
||||
assert (len(walk_directory(d, suffix=None))) == 7
|
||||
assert (len(walk_directory(d, suffix="json"))) == 1
|
||||
assert (len(walk_directory(d, suffix="iob"))) == 2
|
||||
assert (len(walk_directory(d, suffix="conll"))) == 3
|
||||
assert (len(walk_directory(d, suffix="pdf"))) == 0
|
||||
|
||||
|
||||
def test_debug_data_trainable_lemmatizer_basic():
|
||||
examples = [
|
||||
("She likes green eggs", {"lemmas": ["she", "like", "green", "egg"]}),
|
||||
("Eat blue ham", {"lemmas": ["eat", "blue", "ham"]}),
|
||||
]
|
||||
nlp = Language()
|
||||
train_examples = []
|
||||
for t in examples:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
|
||||
# ref test_edit_tree_lemmatizer::test_initialize_from_labels
|
||||
# this results in 4 trees
|
||||
assert len(data["lemmatizer_trees"]) == 4
|
||||
|
||||
|
||||
def test_debug_data_trainable_lemmatizer_partial():
|
||||
partial_examples = [
|
||||
# partial annotation
|
||||
("She likes green eggs", {"lemmas": ["", "like", "green", ""]}),
|
||||
# misaligned partial annotation
|
||||
(
|
||||
"He hates green eggs",
|
||||
{
|
||||
"words": ["He", "hat", "es", "green", "eggs"],
|
||||
"lemmas": ["", "hat", "e", "green", ""],
|
||||
},
|
||||
),
|
||||
]
|
||||
nlp = Language()
|
||||
train_examples = []
|
||||
for t in partial_examples:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
|
||||
assert data["partial_lemma_annotations"] == 2
|
||||
|
||||
|
||||
def test_debug_data_trainable_lemmatizer_low_cardinality():
|
||||
low_cardinality_examples = [
|
||||
("She likes green eggs", {"lemmas": ["no", "no", "no", "no"]}),
|
||||
("Eat blue ham", {"lemmas": ["no", "no", "no"]}),
|
||||
]
|
||||
nlp = Language()
|
||||
train_examples = []
|
||||
for t in low_cardinality_examples:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
|
||||
assert data["n_low_cardinality_lemmas"] == 2
|
||||
|
||||
|
||||
def test_debug_data_trainable_lemmatizer_not_annotated():
|
||||
unannotated_examples = [
|
||||
("She likes green eggs", {}),
|
||||
("Eat blue ham", {}),
|
||||
]
|
||||
nlp = Language()
|
||||
train_examples = []
|
||||
for t in unannotated_examples:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
data = _compile_gold(train_examples, ["trainable_lemmatizer"], nlp, True)
|
||||
assert data["no_lemma_annotations"] == 2
|
||||
|
|
237 spacy/tests/test_cli_app.py Normal file
@@ -0,0 +1,237 @@
import os
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import srsly
|
||||
from typer.testing import CliRunner
|
||||
from spacy.tokens import DocBin, Doc
|
||||
|
||||
from spacy.cli._util import app, get_git_version
|
||||
from .util import make_tempdir, normalize_whitespace
|
||||
|
||||
|
||||
def has_git():
|
||||
try:
|
||||
get_git_version()
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
def test_convert_auto():
|
||||
with make_tempdir() as d_in, make_tempdir() as d_out:
|
||||
for f in ["data1.iob", "data2.iob", "data3.iob"]:
|
||||
Path(d_in / f).touch()
|
||||
|
||||
# ensure that "automatic" suffix detection works
|
||||
result = CliRunner().invoke(app, ["convert", str(d_in), str(d_out)])
|
||||
assert "Generated output file" in result.stdout
|
||||
out_files = os.listdir(d_out)
|
||||
assert len(out_files) == 3
|
||||
assert "data1.spacy" in out_files
|
||||
assert "data2.spacy" in out_files
|
||||
assert "data3.spacy" in out_files
|
||||
|
||||
|
||||
def test_convert_auto_conflict():
|
||||
with make_tempdir() as d_in, make_tempdir() as d_out:
|
||||
for f in ["data1.iob", "data2.iob", "data3.json"]:
|
||||
Path(d_in / f).touch()
|
||||
|
||||
# ensure that "automatic" suffix detection warns when there are different file types
|
||||
result = CliRunner().invoke(app, ["convert", str(d_in), str(d_out)])
|
||||
assert "All input files must be same type" in result.stdout
|
||||
out_files = os.listdir(d_out)
|
||||
assert len(out_files) == 0
|
||||
|
||||
|
||||
def test_benchmark_accuracy_alias():
|
||||
# Verify that the `evaluate` alias works correctly.
|
||||
result_benchmark = CliRunner().invoke(app, ["benchmark", "accuracy", "--help"])
|
||||
result_evaluate = CliRunner().invoke(app, ["evaluate", "--help"])
|
||||
assert normalize_whitespace(result_benchmark.stdout) == normalize_whitespace(
|
||||
result_evaluate.stdout.replace("spacy evaluate", "spacy benchmark accuracy")
|
||||
)
|
||||
|
||||
|
||||
def test_debug_data_trainable_lemmatizer_cli(en_vocab):
|
||||
train_docs = [
|
||||
Doc(en_vocab, words=["I", "like", "cats"], lemmas=["I", "like", "cat"]),
|
||||
Doc(
|
||||
en_vocab,
|
||||
words=["Dogs", "are", "great", "too"],
|
||||
lemmas=["dog", "be", "great", "too"],
|
||||
),
|
||||
]
|
||||
dev_docs = [
|
||||
Doc(en_vocab, words=["Cats", "are", "cute"], lemmas=["cat", "be", "cute"]),
|
||||
Doc(en_vocab, words=["Pets", "are", "great"], lemmas=["pet", "be", "great"]),
|
||||
]
|
||||
with make_tempdir() as d_in:
|
||||
train_bin = DocBin(docs=train_docs)
|
||||
train_bin.to_disk(d_in / "train.spacy")
|
||||
dev_bin = DocBin(docs=dev_docs)
|
||||
dev_bin.to_disk(d_in / "dev.spacy")
|
||||
# `debug data` requires an input pipeline config
|
||||
CliRunner().invoke(
|
||||
app,
|
||||
[
|
||||
"init",
|
||||
"config",
|
||||
f"{d_in}/config.cfg",
|
||||
"--lang",
|
||||
"en",
|
||||
"--pipeline",
|
||||
"trainable_lemmatizer",
|
||||
],
|
||||
)
|
||||
result_debug_data = CliRunner().invoke(
|
||||
app,
|
||||
[
|
||||
"debug",
|
||||
"data",
|
||||
f"{d_in}/config.cfg",
|
||||
"--paths.train",
|
||||
f"{d_in}/train.spacy",
|
||||
"--paths.dev",
|
||||
f"{d_in}/dev.spacy",
|
||||
],
|
||||
)
|
||||
# Instead of checking specific wording of the output, which may change,
|
||||
# we'll check that this section of the debug output is present.
|
||||
assert "= Trainable Lemmatizer =" in result_debug_data.stdout
|
||||
|
||||
|
||||
# project tests
|
||||
|
||||
SAMPLE_PROJECT = {
|
||||
"title": "Sample project",
|
||||
"description": "This is a project for testing",
|
||||
"assets": [
|
||||
{
|
||||
"dest": "assets/spacy-readme.md",
|
||||
"url": "https://github.com/explosion/spaCy/raw/dec81508d28b47f09a06203c472b37f00db6c869/README.md",
|
||||
"checksum": "411b2c89ccf34288fae8ed126bf652f7",
|
||||
},
|
||||
{
|
||||
"dest": "assets/citation.cff",
|
||||
"url": "https://github.com/explosion/spaCy/raw/master/CITATION.cff",
|
||||
"checksum": "c996bfd80202d480eb2e592369714e5e",
|
||||
"extra": True,
|
||||
},
|
||||
],
|
||||
"commands": [
|
||||
{
|
||||
"name": "ok",
|
||||
"help": "print ok",
|
||||
"script": ["python -c \"print('okokok')\""],
|
||||
},
|
||||
{
|
||||
"name": "create",
|
||||
"help": "make a file",
|
||||
"script": ["touch abc.txt"],
|
||||
"outputs": ["abc.txt"],
|
||||
},
|
||||
{
|
||||
"name": "clean",
|
||||
"help": "remove test file",
|
||||
"script": ["rm abc.txt"],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
SAMPLE_PROJECT_TEXT = srsly.yaml_dumps(SAMPLE_PROJECT)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_dir():
|
||||
with make_tempdir() as pdir:
|
||||
(pdir / "project.yml").write_text(SAMPLE_PROJECT_TEXT)
|
||||
yield pdir
|
||||
|
||||
|
||||
def test_project_document(project_dir):
|
||||
readme_path = project_dir / "README.md"
|
||||
assert not readme_path.exists(), "README already exists"
|
||||
result = CliRunner().invoke(
|
||||
app, ["project", "document", str(project_dir), "-o", str(readme_path)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert readme_path.is_file()
|
||||
text = readme_path.read_text("utf-8")
|
||||
assert SAMPLE_PROJECT["description"] in text
|
||||
|
||||
|
||||
def test_project_assets(project_dir):
|
||||
asset_dir = project_dir / "assets"
|
||||
assert not asset_dir.exists(), "Assets dir is already present"
|
||||
result = CliRunner().invoke(app, ["project", "assets", str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert (asset_dir / "spacy-readme.md").is_file(), "Assets not downloaded"
|
||||
# check that extras work
|
||||
result = CliRunner().invoke(app, ["project", "assets", "--extra", str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert (asset_dir / "citation.cff").is_file(), "Extras not downloaded"
|
||||
|
||||
|
||||
def test_project_run(project_dir):
|
||||
# make sure dry run works
|
||||
test_file = project_dir / "abc.txt"
|
||||
result = CliRunner().invoke(
|
||||
app, ["project", "run", "--dry", "create", str(project_dir)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert not test_file.is_file()
|
||||
result = CliRunner().invoke(app, ["project", "run", "create", str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert test_file.is_file()
|
||||
result = CliRunner().invoke(app, ["project", "run", "ok", str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert "okokok" in result.stdout
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_git(), reason="git not installed")
|
||||
@pytest.mark.parametrize(
|
||||
"options",
|
||||
[
|
||||
"",
|
||||
# "--sparse",
|
||||
"--branch v3",
|
||||
"--repo https://github.com/explosion/projects --branch v3",
|
||||
],
|
||||
)
|
||||
def test_project_clone(options):
|
||||
with make_tempdir() as workspace:
|
||||
out = workspace / "project"
|
||||
target = "benchmarks/ner_conll03"
|
||||
if not options:
|
||||
options = []
|
||||
else:
|
||||
options = options.split()
|
||||
result = CliRunner().invoke(
|
||||
app, ["project", "clone", target, *options, str(out)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert (out / "README.md").is_file()
|
||||
|
||||
|
||||
def test_project_push_pull(project_dir):
|
||||
proj = dict(SAMPLE_PROJECT)
|
||||
remote = "xyz"
|
||||
|
||||
with make_tempdir() as remote_dir:
|
||||
proj["remotes"] = {remote: str(remote_dir)}
|
||||
proj_text = srsly.yaml_dumps(proj)
|
||||
(project_dir / "project.yml").write_text(proj_text)
|
||||
|
||||
test_file = project_dir / "abc.txt"
|
||||
result = CliRunner().invoke(app, ["project", "run", "create", str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert test_file.is_file()
|
||||
result = CliRunner().invoke(app, ["project", "push", remote, str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
result = CliRunner().invoke(app, ["project", "run", "clean", str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert not test_file.exists()
|
||||
result = CliRunner().invoke(app, ["project", "pull", remote, str(project_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert test_file.is_file()
|
|
@ -275,6 +275,20 @@ def test_displacy_parse_deps(en_vocab):
|
|||
{"start": 2, "end": 3, "label": "det", "dir": "left"},
|
||||
{"start": 1, "end": 3, "label": "attr", "dir": "right"},
|
||||
]
|
||||
# Test that displacy.parse_deps converts Span to Doc
|
||||
deps = displacy.parse_deps(doc[:])
|
||||
assert isinstance(deps, dict)
|
||||
assert deps["words"] == [
|
||||
{"lemma": None, "text": words[0], "tag": pos[0]},
|
||||
{"lemma": None, "text": words[1], "tag": pos[1]},
|
||||
{"lemma": None, "text": words[2], "tag": pos[2]},
|
||||
{"lemma": None, "text": words[3], "tag": pos[3]},
|
||||
]
|
||||
assert deps["arcs"] == [
|
||||
{"start": 0, "end": 1, "label": "nsubj", "dir": "left"},
|
||||
{"start": 2, "end": 3, "label": "det", "dir": "left"},
|
||||
{"start": 1, "end": 3, "label": "attr", "dir": "right"},
|
||||
]
|
||||
|
||||
|
||||
def test_displacy_invalid_arcs():
|
||||
|
|
|
@@ -3,6 +3,7 @@ import logging
from unittest import mock
import pytest
from spacy.language import Language
from spacy.scorer import Scorer
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.training import Example
@@ -45,7 +46,7 @@ def assert_sents_error(doc):

def warn_error(proc_name, proc, docs, e):
    logger = logging.getLogger("spacy")
    logger.warning(f"Trouble with component {proc_name}.")
    logger.warning("Trouble with component %s.", proc_name)


@pytest.fixture
@ -126,6 +127,112 @@ def test_evaluate_no_pipe(nlp):
|
|||
nlp.evaluate([Example.from_dict(doc, annots)])
|
||||
|
||||
|
||||
def test_evaluate_textcat_multilabel(en_vocab):
|
||||
"""Test that evaluate works with a multilabel textcat pipe."""
|
||||
nlp = Language(en_vocab)
|
||||
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||
textcat_multilabel.add_label(label)
|
||||
nlp.initialize()
|
||||
|
||||
annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
|
||||
doc = nlp.make_doc("hello world")
|
||||
example = Example.from_dict(doc, annots)
|
||||
scores = nlp.evaluate([example])
|
||||
labels = nlp.get_pipe("textcat_multilabel").labels
|
||||
for label in labels:
|
||||
assert scores["cats_f_per_type"].get(label) is not None
|
||||
for key in example.reference.cats.keys():
|
||||
if key not in labels:
|
||||
assert scores["cats_f_per_type"].get(key) is None
|
||||
|
||||
|
||||
def test_evaluate_multiple_textcat_final(en_vocab):
|
||||
"""Test that evaluate evaluates the final textcat component in a pipeline
|
||||
with more than one textcat or textcat_multilabel."""
|
||||
nlp = Language(en_vocab)
|
||||
textcat = nlp.add_pipe("textcat")
|
||||
for label in ("POSITIVE", "NEGATIVE"):
|
||||
textcat.add_label(label)
|
||||
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||
textcat_multilabel.add_label(label)
|
||||
nlp.initialize()
|
||||
|
||||
annots = {
|
||||
"cats": {
|
||||
"POSITIVE": 1.0,
|
||||
"NEGATIVE": 0.0,
|
||||
"FEATURE": 1.0,
|
||||
"QUESTION": 1.0,
|
||||
"POSITIVE": 1.0,
|
||||
"NEGATIVE": 0.0,
|
||||
}
|
||||
}
|
||||
doc = nlp.make_doc("hello world")
|
||||
example = Example.from_dict(doc, annots)
|
||||
scores = nlp.evaluate([example])
|
||||
# get the labels from the final pipe
|
||||
labels = nlp.get_pipe(nlp.pipe_names[-1]).labels
|
||||
for label in labels:
|
||||
assert scores["cats_f_per_type"].get(label) is not None
|
||||
for key in example.reference.cats.keys():
|
||||
if key not in labels:
|
||||
assert scores["cats_f_per_type"].get(key) is None
|
||||
|
||||
|
||||
def test_evaluate_multiple_textcat_separate(en_vocab):
|
||||
"""Test that evaluate can evaluate multiple textcat components separately
|
||||
with custom scorers."""
|
||||
|
||||
def custom_textcat_score(examples, **kwargs):
|
||||
scores = Scorer.score_cats(
|
||||
examples,
|
||||
"cats",
|
||||
multi_label=False,
|
||||
**kwargs,
|
||||
)
|
||||
return {f"custom_{k}": v for k, v in scores.items()}
|
||||
|
||||
@spacy.registry.scorers("test_custom_textcat_scorer")
|
||||
def make_custom_textcat_scorer():
|
||||
return custom_textcat_score
|
||||
|
||||
nlp = Language(en_vocab)
|
||||
textcat = nlp.add_pipe(
|
||||
"textcat",
|
||||
config={"scorer": {"@scorers": "test_custom_textcat_scorer"}},
|
||||
)
|
||||
for label in ("POSITIVE", "NEGATIVE"):
|
||||
textcat.add_label(label)
|
||||
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||
textcat_multilabel.add_label(label)
|
||||
nlp.initialize()
|
||||
|
||||
annots = {
|
||||
"cats": {
|
||||
"POSITIVE": 1.0,
|
||||
"NEGATIVE": 0.0,
|
||||
"FEATURE": 1.0,
|
||||
"QUESTION": 1.0,
|
||||
"POSITIVE": 1.0,
|
||||
"NEGATIVE": 0.0,
|
||||
}
|
||||
}
|
||||
doc = nlp.make_doc("hello world")
|
||||
example = Example.from_dict(doc, annots)
|
||||
scores = nlp.evaluate([example])
|
||||
# check custom scores for the textcat pipe
|
||||
assert "custom_cats_f_per_type" in scores
|
||||
labels = nlp.get_pipe("textcat").labels
|
||||
assert set(scores["custom_cats_f_per_type"].keys()) == set(labels)
|
||||
# check default scores for the textcat_multilabel pipe
|
||||
assert "cats_f_per_type" in scores
|
||||
labels = nlp.get_pipe("textcat_multilabel").labels
|
||||
assert set(scores["cats_f_per_type"].keys()) == set(labels)
|
||||
|
||||
|
||||
def vector_modification_pipe(doc):
|
||||
doc.vector += 1
|
||||
return doc
|
||||
|
|
|
@@ -8,7 +8,7 @@ from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList, import_file
from spacy.util import to_ternary_int
from spacy.util import to_ternary_int, find_available_port
from thinc.api import Config, Optimizer, ConfigValidationError
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
@@ -434,3 +434,16 @@ def test_to_ternary_int():
    assert to_ternary_int(-10) == -1
    assert to_ternary_int("string") == -1
    assert to_ternary_int([0, "string"]) == -1


def test_find_available_port():
    host = "0.0.0.0"
    port = 5000
    assert find_available_port(port, host) == port, "Port 5000 isn't free"

    from wsgiref.simple_server import make_server, demo_app

    with make_server(host, port, demo_app) as httpd:
        with pytest.warns(UserWarning, match="already in use"):
            found_port = find_available_port(port, host, auto_select=True)
        assert found_port == port + 1, "Didn't find next port"

@@ -110,7 +110,7 @@ def test_tokenization(sented_doc):
    )
    example.predicted[1].is_sent_start = False
    scores = scorer.score([example])
    assert scores["token_acc"] == approx(0.66666666)
    assert scores["token_acc"] == 0.5
    assert scores["token_p"] == 0.5
    assert scores["token_r"] == approx(0.33333333)
    assert scores["token_f"] == 0.4

78 spacy/tests/training/test_corpus.py Normal file
@@ -0,0 +1,78 @@
from typing import IO, Generator, Iterable, List, TextIO, Tuple
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
import tempfile
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.training import Example, PlainTextCorpus
|
||||
from spacy.util import make_tempdir
|
||||
|
||||
# Intentional newlines to check that they are skipped.
|
||||
PLAIN_TEXT_DOC = """
|
||||
|
||||
This is a doc. It contains two sentences.
|
||||
This is another doc.
|
||||
|
||||
A third doc.
|
||||
|
||||
"""
|
||||
|
||||
PLAIN_TEXT_DOC_TOKENIZED = [
|
||||
[
|
||||
"This",
|
||||
"is",
|
||||
"a",
|
||||
"doc",
|
||||
".",
|
||||
"It",
|
||||
"contains",
|
||||
"two",
|
||||
"sentences",
|
||||
".",
|
||||
],
|
||||
["This", "is", "another", "doc", "."],
|
||||
["A", "third", "doc", "."],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_length", [0, 5])
|
||||
@pytest.mark.parametrize("max_length", [0, 5])
|
||||
def test_plain_text_reader(min_length, max_length):
|
||||
nlp = English()
|
||||
with _string_to_tmp_file(PLAIN_TEXT_DOC) as file_path:
|
||||
corpus = PlainTextCorpus(
|
||||
file_path, min_length=min_length, max_length=max_length
|
||||
)
|
||||
|
||||
check = [
|
||||
doc
|
||||
for doc in PLAIN_TEXT_DOC_TOKENIZED
|
||||
if len(doc) >= min_length and (max_length == 0 or len(doc) <= max_length)
|
||||
]
|
||||
reference, predicted = _examples_to_tokens(corpus(nlp))
|
||||
|
||||
assert reference == check
|
||||
assert predicted == check
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _string_to_tmp_file(s: str) -> Generator[Path, None, None]:
|
||||
with make_tempdir() as d:
|
||||
file_path = Path(d) / "string.txt"
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(s)
|
||||
yield file_path
|
||||
|
||||
|
||||
def _examples_to_tokens(
|
||||
examples: Iterable[Example],
|
||||
) -> Tuple[List[List[str]], List[List[str]]]:
|
||||
reference = []
|
||||
predicted = []
|
||||
|
||||
for eg in examples:
|
||||
reference.append([t.text for t in eg.reference])
|
||||
predicted.append([t.text for t in eg.predicted])
|
||||
|
||||
return reference, predicted
|
|
@ -2,17 +2,19 @@ from pathlib import Path
|
|||
import numpy as np
|
||||
import pytest
|
||||
import srsly
|
||||
from spacy.vocab import Vocab
|
||||
from thinc.api import Config
|
||||
from thinc.api import Config, get_current_ops
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
from spacy.training.initialize import init_nlp
|
||||
from spacy.training.loop import train
|
||||
from spacy.training.pretrain import pretrain
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH
|
||||
from spacy.ml.models.multi_task import create_pretrain_vectors
|
||||
from spacy.vectors import Vectors
|
||||
from spacy.vocab import Vocab
|
||||
from ..util import make_tempdir
|
||||
from ... import util
|
||||
from ...lang.en import English
|
||||
from ...training.initialize import init_nlp
|
||||
from ...training.loop import train
|
||||
from ...training.pretrain import pretrain
|
||||
from ...tokens import Doc, DocBin
|
||||
from ...language import DEFAULT_CONFIG_PRETRAIN_PATH, DEFAULT_CONFIG_PATH
|
||||
|
||||
pretrain_string_listener = """
|
||||
[nlp]
|
||||
|
@ -163,7 +165,8 @@ def test_pretraining_default():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("objective", CHAR_OBJECTIVES)
|
||||
def test_pretraining_tok2vec_characters(objective):
|
||||
@pytest.mark.parametrize("skip_last", (True, False))
|
||||
def test_pretraining_tok2vec_characters(objective, skip_last):
|
||||
"""Test that pretraining works with the character objective"""
|
||||
config = Config().from_str(pretrain_string_listener)
|
||||
config["pretraining"]["objective"] = objective
|
||||
|
@ -176,10 +179,14 @@ def test_pretraining_tok2vec_characters(objective):
|
|||
filled["paths"]["raw_text"] = file_path
|
||||
filled = filled.interpolate()
|
||||
assert filled["pretraining"]["component"] == "tok2vec"
|
||||
pretrain(filled, tmp_dir)
|
||||
pretrain(filled, tmp_dir, skip_last=skip_last)
|
||||
assert Path(tmp_dir / "model0.bin").exists()
|
||||
assert Path(tmp_dir / "model4.bin").exists()
|
||||
assert not Path(tmp_dir / "model5.bin").exists()
|
||||
if skip_last:
|
||||
assert not Path(tmp_dir / "model-last.bin").exists()
|
||||
else:
|
||||
assert Path(tmp_dir / "model-last.bin").exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("objective", VECTOR_OBJECTIVES)
|
||||
|
@ -235,6 +242,7 @@ def test_pretraining_tagger_tok2vec(config):
|
|||
pretrain(filled, tmp_dir)
|
||||
assert Path(tmp_dir / "model0.bin").exists()
|
||||
assert Path(tmp_dir / "model4.bin").exists()
|
||||
assert Path(tmp_dir / "model-last.bin").exists()
|
||||
assert not Path(tmp_dir / "model5.bin").exists()
|
||||
|
||||
|
||||
|
@ -346,3 +354,26 @@ def write_vectors_model(tmp_dir):
|
|||
nlp = English(vocab)
|
||||
nlp.to_disk(nlp_path)
|
||||
return str(nlp_path)
|
||||
|
||||
|
||||
def test_pretrain_default_vectors():
|
||||
nlp = English()
|
||||
nlp.add_pipe("tok2vec")
|
||||
nlp.initialize()
|
||||
|
||||
# default vectors are supported
|
||||
nlp.vocab.vectors = Vectors(shape=(10, 10))
|
||||
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
|
||||
|
||||
# floret vectors are supported
|
||||
nlp.vocab.vectors = Vectors(
|
||||
data=get_current_ops().xp.zeros((10, 10)), mode="floret", hash_count=1
|
||||
)
|
||||
create_pretrain_vectors(1, 1, "cosine")(nlp.vocab, nlp.get_pipe("tok2vec").model)
|
||||
|
||||
# error for no vectors
|
||||
with pytest.raises(ValueError, match="E875"):
|
||||
nlp.vocab.vectors = Vectors()
|
||||
create_pretrain_vectors(1, 1, "cosine")(
|
||||
nlp.vocab, nlp.get_pipe("tok2vec").model
|
||||
)
|
||||
|
|
|
@@ -1,6 +1,7 @@
import numpy
import tempfile
import contextlib
import re
import srsly
from spacy.tokens import Doc
from spacy.vocab import Vocab
@@ -95,3 +96,7 @@ def assert_packed_msg_equal(b1, b2):
    for (k1, v1), (k2, v2) in zip(sorted(msg1.items()), sorted(msg2.items())):
        assert k1 == k2
        assert v1 == v2


def normalize_whitespace(s):
    return re.sub(r"\s+", " ", s)

@ -124,6 +124,10 @@ class DocBin:
|
|||
for key, group in doc.spans.items():
|
||||
for span in group:
|
||||
self.strings.add(span.label_)
|
||||
if span.kb_id in span.doc.vocab.strings:
|
||||
self.strings.add(span.kb_id_)
|
||||
if span.id in span.doc.vocab.strings:
|
||||
self.strings.add(span.id_)
|
||||
|
||||
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
||||
"""Recover Doc objects from the annotations, using the given vocab.
|
||||
|
|
|
@ -110,6 +110,7 @@ class Doc:
|
|||
kb_id: Union[int, str] = ...,
|
||||
vector: Optional[Floats1d] = ...,
|
||||
alignment_mode: str = ...,
|
||||
span_id: Union[int, str] = ...,
|
||||
) -> Span: ...
|
||||
def similarity(self, other: Union[Doc, Span, Token, Lexeme]) -> float: ...
|
||||
@property
|
||||
|
|
|
@ -530,9 +530,9 @@ cdef class Doc:
|
|||
doc (Doc): The parent document.
|
||||
start_idx (int): The index of the first character of the span.
|
||||
end_idx (int): The index of the first character after the span.
|
||||
label (uint64 or string): A label to attach to the Span, e.g. for
|
||||
label (Union[int, str]): A label to attach to the Span, e.g. for
|
||||
named entities.
|
||||
kb_id (uint64 or string): An ID from a KB to capture the meaning of a
|
||||
kb_id (Union[int, str]): An ID from a KB to capture the meaning of a
|
||||
named entity.
|
||||
vector (ndarray[ndim=1, dtype='float32']): A meaning representation of
|
||||
the span.
|
||||
|
@ -541,14 +541,11 @@ cdef class Doc:
|
|||
with token boundaries), "contract" (span of all tokens completely
|
||||
within the character span), "expand" (span of all tokens at least
|
||||
partially covered by the character span). Defaults to "strict".
|
||||
span_id (Union[int, str]): An identifier to associate with the span.
|
||||
RETURNS (Span): The newly constructed object.
|
||||
|
||||
DOCS: https://spacy.io/api/doc#char_span
|
||||
"""
|
||||
if not isinstance(label, int):
|
||||
label = self.vocab.strings.add(label)
|
||||
if not isinstance(kb_id, int):
|
||||
kb_id = self.vocab.strings.add(kb_id)
|
||||
alignment_modes = ("strict", "contract", "expand")
|
||||
if alignment_mode not in alignment_modes:
|
||||
raise ValueError(
|
||||
|
@ -1359,6 +1356,10 @@ cdef class Doc:
|
|||
for group in self.spans.values():
|
||||
for span in group:
|
||||
strings.add(span.label_)
|
||||
if span.kb_id in span.doc.vocab.strings:
|
||||
strings.add(span.kb_id_)
|
||||
if span.id in span.doc.vocab.strings:
|
||||
strings.add(span.id_)
|
||||
# Msgpack doesn't distinguish between lists and tuples, which is
|
||||
# vexing for user data. As a best guess, we *know* that within
|
||||
# keys, we must have tuples. In values we just have to hope
|
||||
|
|
|
@ -95,9 +95,12 @@ class Span:
|
|||
self,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
label: int = ...,
|
||||
kb_id: int = ...,
|
||||
label: Union[int, str] = ...,
|
||||
kb_id: Union[int, str] = ...,
|
||||
vector: Optional[Floats1d] = ...,
|
||||
id: Union[int, str] = ...,
|
||||
alignment_mode: str = ...,
|
||||
span_id: Union[int, str] = ...,
|
||||
) -> Span: ...
|
||||
@property
|
||||
def conjuncts(self) -> Tuple[Token]: ...
|
||||
|
|
|
@ -362,7 +362,7 @@ cdef class Span:
|
|||
result = xp.dot(vector, other.vector) / (self.vector_norm * other.vector_norm)
|
||||
# ensure we get a scalar back (numpy does this automatically but cupy doesn't)
|
||||
return result.item()
|
||||
|
||||
|
||||
cpdef np.ndarray to_array(self, object py_attr_ids):
|
||||
"""Given a list of M attribute IDs, export the tokens to a numpy
|
||||
`ndarray` of shape `(N, M)`, where `N` is the length of the document.
|
||||
|
@ -467,7 +467,6 @@ cdef class Span:
|
|||
if start == self.doc.length - 1:
|
||||
yield Span(self.doc, start, self.doc.length)
|
||||
|
||||
|
||||
@property
|
||||
def ents(self):
|
||||
"""The named entities that fall completely within the span. Returns
|
||||
|
@ -643,21 +642,28 @@ cdef class Span:
|
|||
else:
|
||||
return self.doc[root]
|
||||
|
||||
def char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None, id=0):
|
||||
def char_span(self, int start_idx, int end_idx, label=0, kb_id=0, vector=None, id=0, alignment_mode="strict", span_id=0):
|
||||
"""Create a `Span` object from the slice `span.text[start : end]`.
|
||||
|
||||
start (int): The index of the first character of the span.
|
||||
end (int): The index of the first character after the span.
|
||||
label (uint64 or string): A label to attach to the Span, e.g. for
|
||||
label (Union[int, str]): A label to attach to the Span, e.g. for
|
||||
named entities.
|
||||
kb_id (uint64 or string): An ID from a KB to capture the meaning of a named entity.
|
||||
kb_id (Union[int, str]): An ID from a KB to capture the meaning of a named entity.
|
||||
vector (ndarray[ndim=1, dtype='float32']): A meaning representation of
|
||||
the span.
|
||||
id (Union[int, str]): Unused.
|
||||
alignment_mode (str): How character indices are aligned to token
|
||||
boundaries. Options: "strict" (character indices must be aligned
|
||||
with token boundaries), "contract" (span of all tokens completely
|
||||
within the character span), "expand" (span of all tokens at least
|
||||
partially covered by the character span). Defaults to "strict".
|
||||
span_id (Union[int, str]): An identifier to associate with the span.
|
||||
RETURNS (Span): The newly constructed object.
|
||||
"""
|
||||
start_idx += self.c.start_char
|
||||
end_idx += self.c.start_char
|
||||
return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector)
|
||||
return self.doc.char_span(start_idx, end_idx, label=label, kb_id=kb_id, vector=vector, alignment_mode=alignment_mode, span_id=span_id)
|
||||
|
||||
@property
|
||||
def conjuncts(self):
|
||||
|
|
|
@ -18,6 +18,7 @@ class SpanGroup:
|
|||
def doc(self) -> Doc: ...
|
||||
@property
|
||||
def has_overlap(self) -> bool: ...
|
||||
def __iter__(self): ...
|
||||
def __len__(self) -> int: ...
|
||||
def append(self, span: Span) -> None: ...
|
||||
def extend(self, spans: Iterable[Span]) -> None: ...
|
||||
|
|
|
@ -158,6 +158,16 @@ cdef class SpanGroup:
|
|||
return self._concat(other)
|
||||
return NotImplemented
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Iterate over the spans in this SpanGroup.
|
||||
YIELDS (Span): A span in this SpanGroup.
|
||||
|
||||
DOCS: https://spacy.io/api/spangroup#iter
|
||||
"""
|
||||
for i in range(self.c.size()):
|
||||
yield self[i]
|
||||
|
||||
def append(self, Span span):
|
||||
"""Add a span to the group. The span must refer to the same Doc
|
||||
object as the span group.
|
||||
|
|
|
@@ -1,4 +1,4 @@
from .corpus import Corpus, JsonlCorpus  # noqa: F401
from .corpus import Corpus, JsonlCorpus, PlainTextCorpus  # noqa: F401
from .example import Example, validate_examples, validate_get_examples  # noqa: F401
from .alignment import Alignment  # noqa: F401
from .augment import dont_augment, orth_variants_augmenter  # noqa: F401

@ -11,7 +11,7 @@ def create_copy_from_base_model(
|
|||
) -> Callable[[Language], Language]:
|
||||
def copy_from_base_model(nlp):
|
||||
if tokenizer:
|
||||
logger.info(f"Copying tokenizer from: {tokenizer}")
|
||||
logger.info("Copying tokenizer from: %s", tokenizer)
|
||||
base_nlp = load_model(tokenizer)
|
||||
if nlp.config["nlp"]["tokenizer"] == base_nlp.config["nlp"]["tokenizer"]:
|
||||
nlp.tokenizer.from_bytes(base_nlp.tokenizer.to_bytes(exclude=["vocab"]))
|
||||
|
@ -23,7 +23,7 @@ def create_copy_from_base_model(
|
|||
)
|
||||
)
|
||||
if vocab:
|
||||
logger.info(f"Copying vocab from: {vocab}")
|
||||
logger.info("Copying vocab from: %s", vocab)
|
||||
# only reload if the vocab is from a different model
|
||||
if tokenizer != vocab:
|
||||
base_nlp = load_model(vocab)
|
||||
|
|
|
@ -29,7 +29,7 @@ def create_docbin_reader(
|
|||
) -> Callable[["Language"], Iterable[Example]]:
|
||||
if path is None:
|
||||
raise ValueError(Errors.E913)
|
||||
util.logger.debug(f"Loading corpus from path: {path}")
|
||||
util.logger.debug("Loading corpus from path: %s", path)
|
||||
return Corpus(
|
||||
path,
|
||||
gold_preproc=gold_preproc,
|
||||
|
@ -58,6 +58,28 @@ def read_labels(path: Path, *, require: bool = False):
|
|||
return srsly.read_json(path)
|
||||
|
||||
|
||||
@util.registry.readers("spacy.PlainTextCorpus.v1")
|
||||
def create_plain_text_reader(
|
||||
path: Optional[Path],
|
||||
min_length: int = 0,
|
||||
max_length: int = 0,
|
||||
) -> Callable[["Language"], Iterable[Doc]]:
|
||||
"""Iterate Example objects from a file or directory of plain text
|
||||
UTF-8 files with one line per doc.
|
||||
|
||||
path (Path): The directory or filename to read from.
|
||||
min_length (int): Minimum document length (in tokens). Shorter documents
|
||||
will be skipped. Defaults to 0, which indicates no limit.
|
||||
max_length (int): Maximum document length (in tokens). Longer documents will
|
||||
be skipped. Defaults to 0, which indicates no limit.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#plaintextcorpus
|
||||
"""
|
||||
if path is None:
|
||||
raise ValueError(Errors.E913)
|
||||
return PlainTextCorpus(path, min_length=min_length, max_length=max_length)
|
||||
|
||||
|
||||
def walk_corpus(path: Union[str, Path], file_type) -> List[Path]:
|
||||
path = util.ensure_path(path)
|
||||
if not path.is_dir() and path.parts[-1].endswith(file_type):
|
||||
|
@@ -257,3 +279,52 @@ class JsonlCorpus:
                # We don't *need* an example here, but it seems nice to
                # make it match the Corpus signature.
                yield Example(doc, Doc(nlp.vocab, words=words, spaces=spaces))


class PlainTextCorpus:
    """Iterate Example objects from a file or directory of plain text
    UTF-8 files with one line per doc.

    path (Path): The directory or filename to read from.
    min_length (int): Minimum document length (in tokens). Shorter documents
        will be skipped. Defaults to 0, which indicates no limit.
    max_length (int): Maximum document length (in tokens). Longer documents will
        be skipped. Defaults to 0, which indicates no limit.

    DOCS: https://spacy.io/api/corpus#plaintextcorpus
    """

    file_type = "txt"

    def __init__(
        self,
        path: Optional[Union[str, Path]],
        *,
        min_length: int = 0,
        max_length: int = 0,
    ) -> None:
        self.path = util.ensure_path(path)
        self.min_length = min_length
        self.max_length = max_length

    def __call__(self, nlp: "Language") -> Iterator[Example]:
        """Yield examples from the data.

        nlp (Language): The current nlp object.
        YIELDS (Example): The example objects.

        DOCS: https://spacy.io/api/corpus#plaintextcorpus-call
        """
        for loc in walk_corpus(self.path, ".txt"):
            with open(loc, encoding="utf-8") as f:
                for text in f:
                    text = text.rstrip("\r\n")
                    if len(text):
                        doc = nlp.make_doc(text)
                        if self.min_length >= 1 and len(doc) < self.min_length:
                            continue
                        elif self.max_length >= 1 and len(doc) > self.max_length:
                            continue
                        # We don't *need* an example here, but it seems nice to
                        # make it match the Corpus signature.
                        yield Example(doc, doc.copy())

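A small usage sketch of the class added above, assuming a local file corpus/train.txt with one UTF-8 document per line (the path and the blank English pipeline are illustrative only):

import spacy
from spacy.training import PlainTextCorpus

nlp = spacy.blank("en")
corpus = PlainTextCorpus("corpus/train.txt", min_length=1, max_length=0)
for example in corpus(nlp):
    # Each non-empty line becomes a Doc; the reference side is just a copy.
    print(example.reference.text)
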
@@ -62,10 +62,10 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
    frozen_components = T["frozen_components"]
    # Sourced components that require resume_training
    resume_components = [p for p in sourced if p not in frozen_components]
    logger.info(f"Pipeline: {nlp.pipe_names}")
    logger.info("Pipeline: %s", nlp.pipe_names)
    if resume_components:
        with nlp.select_pipes(enable=resume_components):
            logger.info(f"Resuming training for: {resume_components}")
            logger.info("Resuming training for: %s", resume_components)
            nlp.resume_training(sgd=optimizer)
    # Make sure that listeners are defined before initializing further
    nlp._link_components()

@@ -73,16 +73,17 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
    if T["max_epochs"] == -1:
        sample_size = 100
        logger.debug(
            f"Due to streamed train corpus, using only first {sample_size} "
            f"examples for initialization. If necessary, provide all labels "
            f"in [initialize]. More info: https://spacy.io/api/cli#init_labels"
            "Due to streamed train corpus, using only first %s examples for initialization. "
            "If necessary, provide all labels in [initialize]. "
            "More info: https://spacy.io/api/cli#init_labels",
            sample_size,
        )
        nlp.initialize(
            lambda: islice(train_corpus(nlp), sample_size), sgd=optimizer
        )
    else:
        nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
    logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
    logger.info("Initialized pipeline components: %s", nlp.pipe_names)
    # Detect components with listeners that are not frozen consistently
    for name, proc in nlp.pipeline:
        for listener in getattr(

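When the train corpus is streamed (max_epochs == -1), initialization above only consumes a fixed-size prefix via itertools.islice. A self-contained sketch of that pattern with a stand-in generator (not spaCy's corpus reader):

from itertools import islice

def stream():
    # Stands in for a streamed corpus of unknown, possibly unbounded length.
    i = 0
    while True:
        yield f"doc {i}"
        i += 1

sample_size = 100
sample = list(islice(stream(), sample_size))
print(len(sample))  # 100; the rest of the stream is never touched
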
@@ -109,7 +110,7 @@ def init_vocab(
) -> None:
    if lookups:
        nlp.vocab.lookups = lookups
        logger.info(f"Added vocab lookups: {', '.join(lookups.tables)}")
        logger.info("Added vocab lookups: %s", ", ".join(lookups.tables))
    data_path = ensure_path(data)
    if data_path is not None:
        lex_attrs = srsly.read_jsonl(data_path)

@@ -125,11 +126,11 @@ def init_vocab(
        else:
            oov_prob = DEFAULT_OOV_PROB
        nlp.vocab.cfg.update({"oov_prob": oov_prob})
        logger.info(f"Added {len(nlp.vocab)} lexical entries to the vocab")
        logger.info("Added %d lexical entries to the vocab", len(nlp.vocab))
    logger.info("Created vocabulary")
    if vectors is not None:
        load_vectors_into_model(nlp, vectors)
        logger.info(f"Added vectors: {vectors}")
        logger.info("Added vectors: %s", vectors)
    # warn if source model vectors are not identical
    sourced_vectors_hashes = nlp.meta.pop("_sourced_vectors_hashes", {})
    vectors_hash = hash(nlp.vocab.vectors.to_bytes(exclude=["strings"]))

@@ -191,7 +192,7 @@ def init_tok2vec(
    if weights_data is not None:
        layer = get_tok2vec_ref(nlp, P)
        layer.from_bytes(weights_data)
        logger.info(f"Loaded pretrained weights from {init_tok2vec}")
        logger.info("Loaded pretrained weights from %s", init_tok2vec)
        return True
    return False

@@ -216,13 +217,13 @@ def convert_vectors(
        nlp.vocab.deduplicate_vectors()
    else:
        if vectors_loc:
            logger.info(f"Reading vectors from {vectors_loc}")
            logger.info("Reading vectors from %s", vectors_loc)
            vectors_data, vector_keys, floret_settings = read_vectors(
                vectors_loc,
                truncate,
                mode=mode,
            )
            logger.info(f"Loaded vectors from {vectors_loc}")
            logger.info("Loaded vectors from %s", vectors_loc)
        else:
            vectors_data, vector_keys = (None, None)
        if vector_keys is not None and mode != VectorsMode.floret:

@@ -26,6 +26,8 @@ def setup_table(
    return final_cols, final_widths, ["r" for _ in final_widths]


# We cannot rename this method as it's directly imported
# and used by external packages such as spacy-loggers.
@registry.loggers("spacy.ConsoleLogger.v2")
def console_logger(
    progress_bar: bool = False,

@@ -33,7 +35,27 @@ def console_logger(
    output_file: Optional[Union[str, Path]] = None,
):
    """The ConsoleLogger.v2 prints out training logs in the console and/or saves them to a jsonl file.
    progress_bar (bool): Whether the logger should print the progress bar.
    progress_bar (bool): Whether the logger should print a progress bar tracking the steps till the next evaluation pass.
    console_output (bool): Whether the logger should print the logs on the console.
    output_file (Optional[Union[str, Path]]): The file to save the training logs to.
    """
    return console_logger_v3(
        progress_bar=None if progress_bar is False else "eval",
        console_output=console_output,
        output_file=output_file,
    )


@registry.loggers("spacy.ConsoleLogger.v3")
def console_logger_v3(
    progress_bar: Optional[str] = None,
    console_output: bool = True,
    output_file: Optional[Union[str, Path]] = None,
):
    """The ConsoleLogger.v3 prints out training logs in the console and/or saves them to a jsonl file.
    progress_bar (Optional[str]): Type of progress bar to show in the console. Allowed values:
        train - Tracks the number of steps from the beginning of training until the full training run is complete (training.max_steps is reached).
        eval - Tracks the number of steps between the previous and next evaluation (training.eval_frequency is reached).
    console_output (bool): Whether the logger should print the logs on the console.
    output_file (Optional[Union[str, Path]]): The file to save the training logs to.
    """

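A hedged sketch of obtaining the new v3 logger through the registry; normally it is selected via [training.logger] in the config, and only the registered name and the parameters shown above are taken from the diff:

from spacy.training import loggers  # noqa: F401  # ensures the built-in loggers are registered
from spacy.util import registry

make_logger = registry.loggers.get("spacy.ConsoleLogger.v3")
# progress_bar accepts None, "train" or "eval"; the v2 wrapper maps True to "eval".
setup_logger = make_logger(progress_bar="train", console_output=True, output_file=None)
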
|
@ -70,6 +92,7 @@ def console_logger(
|
|||
for name, proc in nlp.pipeline
|
||||
if hasattr(proc, "is_trainable") and proc.is_trainable
|
||||
]
|
||||
max_steps = nlp.config["training"]["max_steps"]
|
||||
eval_frequency = nlp.config["training"]["eval_frequency"]
|
||||
score_weights = nlp.config["training"]["score_weights"]
|
||||
score_cols = [col for col, value in score_weights.items() if value is not None]
|
||||
|
@ -84,6 +107,13 @@ def console_logger(
|
|||
write(msg.row(table_header, widths=table_widths, spacing=spacing))
|
||||
write(msg.row(["-" * width for width in table_widths], spacing=spacing))
|
||||
progress = None
|
||||
expected_progress_types = ("train", "eval")
|
||||
if progress_bar is not None and progress_bar not in expected_progress_types:
|
||||
raise ValueError(
|
||||
Errors.E1048.format(
|
||||
unexpected=progress_bar, expected=expected_progress_types
|
||||
)
|
||||
)
|
||||
|
||||
def log_step(info: Optional[Dict[str, Any]]) -> None:
|
||||
nonlocal progress
|
||||
|
@ -141,11 +171,23 @@ def console_logger(
|
|||
)
|
||||
)
|
||||
if progress_bar:
|
||||
if progress_bar == "train":
|
||||
total = max_steps
|
||||
desc = f"Last Eval Epoch: {info['epoch']}"
|
||||
initial = info["step"]
|
||||
else:
|
||||
total = eval_frequency
|
||||
desc = f"Epoch {info['epoch']+1}"
|
||||
initial = 0
|
||||
# Set disable=None, so that it disables on non-TTY
|
||||
progress = tqdm.tqdm(
|
||||
total=eval_frequency, disable=None, leave=False, file=stderr
|
||||
total=total,
|
||||
disable=None,
|
||||
leave=False,
|
||||
file=stderr,
|
||||
initial=initial,
|
||||
)
|
||||
progress.set_description(f"Epoch {info['epoch']+1}")
|
||||
progress.set_description(desc)
|
||||
|
||||
def finalize() -> None:
|
||||
if output_stream:
|
||||
|
|
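For reference, a stand-alone sketch of the tqdm call shape used in the "train" branch above; the step counts are made up and tqdm is assumed to be installed:

import sys

import tqdm

max_steps, current_step = 1000, 250
# disable=None lets tqdm switch itself off when stderr is not a TTY.
progress = tqdm.tqdm(
    total=max_steps, initial=current_step, disable=None, leave=False, file=sys.stderr
)
progress.set_description("Last Eval Epoch: 3")
progress.update(10)
progress.close()
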
Some files were not shown because too many files have changed in this diff.