Merge branch 'develop' into spacy.io

This commit is contained in:
Ines Montani 2019-02-24 12:08:15 +01:00
commit f34d6281d6
14 changed files with 102 additions and 102 deletions

View File

@ -1,51 +1,21 @@
environment: environment:
matrix: matrix:
# For Python versions available on Appveyor, see
# http://www.appveyor.com/docs/installed-software#python
#- PYTHON: "C:\\Python27-x64"
#- PYTHON: "C:\\Python34"
#- PYTHON: "C:\\Python35"
#- DISTUTILS_USE_SDK: "1"
#- PYTHON: "C:\\Python34-x64"
#- DISTUTILS_USE_SDK: "1"
- PYTHON: "C:\\Python35-x64" - PYTHON: "C:\\Python35-x64"
- PYTHON: "C:\\Python36-x64" - PYTHON: "C:\\Python36-x64"
- PYTHON: "C:\\Python37-x64" - PYTHON: "C:\\Python37-x64"
install: install:
# We need wheel installed to build wheels # We need wheel installed to build wheels
- "%PYTHON%\\python.exe -m pip install wheel" - "%PYTHON%\\python.exe -m pip install wheel"
- "%PYTHON%\\python.exe -m pip install cython" - "%PYTHON%\\python.exe -m pip install cython"
- "%PYTHON%\\python.exe -m pip install -r requirements.txt" - "%PYTHON%\\python.exe -m pip install -r requirements.txt"
- "%PYTHON%\\python.exe -m pip install -e ." - "%PYTHON%\\python.exe -m pip install -e ."
build: off build: off
test_script: test_script:
# Put your test command here.
# If you don't need to build C extensions on 64-bit Python 3.4,
# you can remove "build.cmd" from the front of the command, as it's
# only needed to support those cases.
# Note that you must use the environment variable %PYTHON% to refer to
# the interpreter you're using - Appveyor does not do anything special
# to put the Python version you want to use on PATH.
- "%PYTHON%\\python.exe -m pytest spacy/ --no-print-logs" - "%PYTHON%\\python.exe -m pytest spacy/ --no-print-logs"
after_test: after_test:
# This step builds your wheels.
# Again, you only need build.cmd if you're building C extensions for
# 64-bit Python 3.4. And you need to use %PYTHON% to get the correct
# interpreter
- "%PYTHON%\\python.exe setup.py bdist_wheel" - "%PYTHON%\\python.exe setup.py bdist_wheel"
artifacts: artifacts:
# bdist_wheel puts your built wheel in the dist directory
- path: dist\* - path: dist\*
branches:
#on_success: except:
# You can use this step to upload your artifacts to a public website. - spacy.io
# See Appveyor's documentation for more details. Or you can simply
# access your wheels from the Appveyor "artifacts" tab for your build.

View File

@ -1,26 +1,20 @@
language: python language: python
sudo: false sudo: false
cache: pip
dist: trusty dist: trusty
group: edge group: edge
python: python:
- "2.7" - "2.7"
- "3.5" - "3.5"
- "3.6" - "3.6"
os: os:
- linux - linux
env: env:
- VIA=compile - VIA=compile
- VIA=flake8 - VIA=flake8
#- VIA=pypi_nightly
install: install:
- "./travis.sh" - "./travis.sh"
- pip install flake8 - pip install flake8
script: script:
- "cat /proc/cpuinfo | grep flags | head -n 1" - "cat /proc/cpuinfo | grep flags | head -n 1"
- "pip install pytest pytest-timeout" - "pip install pytest pytest-timeout"
@ -28,10 +22,10 @@ script:
- if [[ "${VIA}" == "flake8" ]]; then flake8 . --count --exclude=spacy/compat.py,spacy/lang --select=E901,E999,F821,F822,F823 --show-source --statistics; fi - if [[ "${VIA}" == "flake8" ]]; then flake8 . --count --exclude=spacy/compat.py,spacy/lang --select=E901,E999,F821,F822,F823 --show-source --statistics; fi
- if [[ "${VIA}" == "pypi_nightly" ]]; then python -m pytest --tb=native --models --en `python -c "import os.path; import spacy; print(os.path.abspath(os.path.dirname(spacy.__file__)))"`; fi - if [[ "${VIA}" == "pypi_nightly" ]]; then python -m pytest --tb=native --models --en `python -c "import os.path; import spacy; print(os.path.abspath(os.path.dirname(spacy.__file__)))"`; fi
- if [[ "${VIA}" == "sdist" ]]; then python -m pytest --tb=native `python -c "import os.path; import spacy; print(os.path.abspath(os.path.dirname(spacy.__file__)))"`; fi - if [[ "${VIA}" == "sdist" ]]; then python -m pytest --tb=native `python -c "import os.path; import spacy; print(os.path.abspath(os.path.dirname(spacy.__file__)))"`; fi
branches:
except:
- spacy.io
notifications: notifications:
slack: slack:
secure: F8GvqnweSdzImuLL64TpfG0i5rYl89liyr9tmFVsHl4c0DNiDuGhZivUz0M1broS8svE3OPOllLfQbACG/4KxD890qfF9MoHzvRDlp7U+RtwMV/YAkYn8MGWjPIbRbX0HpGdY7O2Rc9Qy4Kk0T8ZgiqXYIqAz2Eva9/9BlSmsJQ= secure: F8GvqnweSdzImuLL64TpfG0i5rYl89liyr9tmFVsHl4c0DNiDuGhZivUz0M1broS8svE3OPOllLfQbACG/4KxD890qfF9MoHzvRDlp7U+RtwMV/YAkYn8MGWjPIbRbX0HpGdY7O2Rc9Qy4Kk0T8ZgiqXYIqAz2Eva9/9BlSmsJQ=
email: false email: false
cache: pip

View File

@ -41,7 +41,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
# add the text classifier to the pipeline if it doesn't exist # add the text classifier to the pipeline if it doesn't exist
# nlp.create_pipe works for built-ins that are registered with spaCy # nlp.create_pipe works for built-ins that are registered with spaCy
if "textcat" not in nlp.pipe_names: if "textcat" not in nlp.pipe_names:
textcat = nlp.create_pipe("textcat") textcat = nlp.create_pipe("textcat", config={
"architecture": "simple_cnn",
"exclusive_classes": True})
nlp.add_pipe(textcat, last=True) nlp.add_pipe(textcat, last=True)
# otherwise, get it, so we can add labels to it # otherwise, get it, so we can add labels to it
else: else:
@ -70,7 +72,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
for i in range(n_iter): for i in range(n_iter):
losses = {} losses = {}
# batch up the examples using spaCy's minibatch # batch up the examples using spaCy's minibatch
batches = minibatch(train_data, size=compounding(4.0, 16.0, 1.001)) batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
for batch in batches: for batch in batches:
texts, annotations = zip(*batch) texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses) nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
@ -138,6 +140,9 @@ def evaluate(tokenizer, textcat, texts, cats):
fn += 1 fn += 1
precision = tp / (tp + fp) precision = tp / (tp + fp)
recall = tp / (tp + fn) recall = tp / (tp + fn)
if (precision+recall) == 0:
f_score = 0.0
else:
f_score = 2 * (precision * recall) / (precision + recall) f_score = 2 * (precision * recall) / (precision + recall)
return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score} return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}

View File

@ -1,7 +1,7 @@
# Our libraries # Our libraries
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=2.0.1,<2.1.0 preshed>=2.0.1,<2.1.0
thinc>=7.0.1,<7.1.0 thinc>=7.0.2,<7.1.0
blis>=0.2.2,<0.3.0 blis>=0.2.2,<0.3.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
wasabi>=0.0.12,<1.1.0 wasabi>=0.0.12,<1.1.0

View File

@ -227,7 +227,7 @@ def setup_package():
"murmurhash>=0.28.0,<1.1.0", "murmurhash>=0.28.0,<1.1.0",
"cymem>=2.0.2,<2.1.0", "cymem>=2.0.2,<2.1.0",
"preshed>=2.0.1,<2.1.0", "preshed>=2.0.1,<2.1.0",
"thinc>=7.0.1,<7.1.0", "thinc>=7.0.2,<7.1.0",
"blis>=0.2.2,<0.3.0", "blis>=0.2.2,<0.3.0",
"plac<1.0.0,>=0.9.6", "plac<1.0.0,>=0.9.6",
"requests>=2.13.0,<3.0.0", "requests>=2.13.0,<3.0.0",

View File

@ -72,10 +72,10 @@ def _flatten_add_lengths(seqs, pad=0, drop=0.0):
def _zero_init(model): def _zero_init(model):
def _zero_init_impl(self, X, y): def _zero_init_impl(self, *args, **kwargs):
self.W.fill(0) self.W.fill(0)
model.on_data_hooks.append(_zero_init_impl) model.on_init_hooks.append(_zero_init_impl)
if model.W is not None: if model.W is not None:
model.W.fill(0.0) model.W.fill(0.0)
return model return model
@ -564,18 +564,26 @@ def build_text_classifier(nr_class, width=64, **cfg):
) )
linear_model = _preprocess_doc >> LinearModel(nr_class) linear_model = _preprocess_doc >> LinearModel(nr_class)
model = ( if cfg.get('exclusive_classes'):
(linear_model | cnn_model) output_layer = Softmax(nr_class, nr_class * 2)
>> zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0)) else:
output_layer = (
zero_init(Affine(nr_class, nr_class * 2, drop_factor=0.0))
>> logistic >> logistic
) )
model.tok2vec = tok2vec
model = (
(linear_model | cnn_model)
>> output_layer
)
model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class model.nO = nr_class
model.lsuv = False model.lsuv = False
return model return model
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True, **cfg): def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
""" """
Build a simple CNN text classifier, given a token-to-vector model as inputs. Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the If exclusive_classes=True, a softmax non-linearity is applied, so that the
@ -586,7 +594,7 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=True,
if exclusive_classes: if exclusive_classes:
output_layer = Softmax(nr_class, tok2vec.nO) output_layer = Softmax(nr_class, tok2vec.nO)
else: else:
output_layer = zero_init(Affine(nr_class, tok2vec.nO)) >> logistic output_layer = zero_init(Affine(nr_class, tok2vec.nO, drop_factor=0.0)) >> logistic
model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer model = tok2vec >> flatten_add_lengths >> Pooling(mean_pool) >> output_layer
model.tok2vec = chain(tok2vec, flatten) model.tok2vec = chain(tok2vec, flatten)
model.nO = nr_class model.nO = nr_class

View File

@ -4,13 +4,13 @@
# fmt: off # fmt: off
__title__ = "spacy-nightly" __title__ = "spacy-nightly"
__version__ = "2.1.0a8" __version__ = "2.1.0a9.dev1"
__summary__ = "Industrial-strength Natural Language Processing (NLP) with Python and Cython" __summary__ = "Industrial-strength Natural Language Processing (NLP) with Python and Cython"
__uri__ = "https://spacy.io" __uri__ = "https://spacy.io"
__author__ = "Explosion AI" __author__ = "Explosion AI"
__email__ = "contact@explosion.ai" __email__ = "contact@explosion.ai"
__license__ = "MIT" __license__ = "MIT"
__release__ = True __release__ = False
__download_url__ = "https://github.com/explosion/spacy-models/releases/download" __download_url__ = "https://github.com/explosion/spacy-models/releases/download"
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json" __compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"

View File

@ -253,10 +253,10 @@ class EntityRenderer(object):
label = span["label"] label = span["label"]
start = span["start"] start = span["start"]
end = span["end"] end = span["end"]
entity = text[start:end] entity = escape_html(text[start:end])
fragments = text[offset:start].split("\n") fragments = text[offset:start].split("\n")
for i, fragment in enumerate(fragments): for i, fragment in enumerate(fragments):
markup += fragment markup += escape_html(fragment)
if len(fragments) > 1 and i != len(fragments) - 1: if len(fragments) > 1 and i != len(fragments) - 1:
markup += "</br>" markup += "</br>"
if self.ents is None or label.upper() in self.ents: if self.ents is None or label.upper() in self.ents:
@ -265,7 +265,7 @@ class EntityRenderer(object):
else: else:
markup += entity markup += entity
offset = end offset = end
markup += text[offset:] markup += escape_html(text[offset:])
markup = TPL_ENTS.format(content=markup, colors=self.colors) markup = TPL_ENTS.format(content=markup, colors=self.colors)
if title: if title:
markup = TPL_TITLE.format(title=title) + markup markup = TPL_TITLE.format(title=title) + markup

View File

@ -24,7 +24,8 @@ from ..vocab cimport Vocab
from ..syntax import nonproj from ..syntax import nonproj
from ..attrs import POS, ID from ..attrs import POS, ID
from ..parts_of_speech import X from ..parts_of_speech import X
from .._ml import Tok2Vec, build_tagger_model, build_simple_cnn_text_classifier from .._ml import Tok2Vec, build_tagger_model
from .._ml import build_text_classifier, build_simple_cnn_text_classifier
from .._ml import link_vectors_to_models, zero_init, flatten from .._ml import link_vectors_to_models, zero_init, flatten
from .._ml import masked_language_model, create_default_optimizer from .._ml import masked_language_model, create_default_optimizer
from ..errors import Errors, TempErrors from ..errors import Errors, TempErrors
@ -862,8 +863,11 @@ class TextCategorizer(Pipe):
token_vector_width = cfg["token_vector_width"] token_vector_width = cfg["token_vector_width"]
else: else:
token_vector_width = util.env_opt("token_vector_width", 96) token_vector_width = util.env_opt("token_vector_width", 96)
if cfg.get('architecture') == 'simple_cnn':
tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg) tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg)
return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg) return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg)
else:
return build_text_classifier(nr_class, **cfg)
@property @property
def tok2vec(self): def tok2vec(self):
@ -942,7 +946,7 @@ class TextCategorizer(Pipe):
not_missing = self.model.ops.asarray(not_missing) not_missing = self.model.ops.asarray(not_missing)
d_scores = (scores-truths) / scores.shape[0] d_scores = (scores-truths) / scores.shape[0]
d_scores *= not_missing d_scores *= not_missing
mean_square_error = ((scores-truths)**2).sum(axis=1).mean() mean_square_error = (d_scores**2).sum(axis=1).mean()
return float(mean_square_error), d_scores return float(mean_square_error), d_scores
def add_label(self, label): def add_label(self, label):
@ -964,11 +968,6 @@ class TextCategorizer(Pipe):
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
**kwargs): **kwargs):
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
token_vector_width = pipeline[0].model.nO
else:
token_vector_width = 64
if self.model is True: if self.model is True:
self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors') self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
self.model = self.Model(len(self.labels), **self.cfg) self.model = self.Model(len(self.labels), **self.cfg)

View File

@ -204,6 +204,8 @@ class ParserModel(Model):
if new_output == self.upper.nO: if new_output == self.upper.nO:
return return
smaller = self.upper smaller = self.upper
with Model.use_device('cpu'):
larger = Affine(new_output, smaller.nI) larger = Affine(new_output, smaller.nI)
# Set nan as value for unseen classes, to prevent prediction. # Set nan as value for unseen classes, to prevent prediction.
larger.W.fill(self.ops.xp.nan) larger.W.fill(self.ops.xp.nan)

View File

@ -0,0 +1,16 @@
# coding: utf8
from __future__ import unicode_literals
from spacy import displacy
from spacy.tokens import Doc, Span
def test_issue2728(en_vocab):
"""Test that displaCy ENT visualizer escapes HTML correctly."""
doc = Doc(en_vocab, words=["test", "<RELEASE>", "test"])
doc.ents = [Span(doc, 0, 1, label="TEST")]
html = displacy.render(doc, style="ent")
assert "&lt;RELEASE&gt;" in html
doc.ents = [Span(doc, 1, 2, label="TEST")]
html = displacy.render(doc, style="ent")
assert "&lt;RELEASE&gt;" in html

View File

@ -107,8 +107,8 @@ details and examples.
> >
> ```python > ```python
> from spacy.attrs import ORTH, LEMMA > from spacy.attrs import ORTH, LEMMA
> case = [{"don't": [{ORTH: "do"}, {ORTH: "n't", LEMMA: "not"}]}] > case = [{ORTH: "do"}, {ORTH: "n't", LEMMA: "not"}]
> tokenizer.add_special_case(case) > tokenizer.add_special_case("don't", case)
> ``` > ```
| Name | Type | Description | | Name | Type | Description |

View File

@ -8,7 +8,7 @@ menu:
- ['Changelog', 'changelog'] - ['Changelog', 'changelog']
--- ---
spaCy is compatible with **64-bit CPython 2.6+/3.3+** and runs on spaCy is compatible with **64-bit CPython 2.7+/3.4+** and runs on
**Unix/Linux**, **macOS/OS X** and **Windows**. The latest spaCy releases are **Unix/Linux**, **macOS/OS X** and **Windows**. The latest spaCy releases are
available over [pip](https://pypi.python.org/pypi/spacy) and available over [pip](https://pypi.python.org/pypi/spacy) and
[conda](https://anaconda.org/conda-forge/spacy). [conda](https://anaconda.org/conda-forge/spacy).

View File

@ -10,11 +10,11 @@ menu:
spaCy v2.1 has focussed primarily on stability and performance, solidifying the spaCy v2.1 has focussed primarily on stability and performance, solidifying the
design changes introduced in [v2.0](/usage/v2). As well as smaller models, design changes introduced in [v2.0](/usage/v2). As well as smaller models,
faster runtime, and many bug-fixes, v2.1 also introduces experimental support faster runtime, and many bug fixes, v2.1 also introduces experimental support
for some exciting new NLP innovations. For the full changelog, see the for some exciting new NLP innovations. For the full changelog, see the
[release notes on GitHub](https://github.com/explosion/spaCy/releases/tag/v2.1.0). [release notes on GitHub](https://github.com/explosion/spaCy/releases/tag/v2.1.0).
### BERT/ULMFit/Elmo-style pre-training ### BERT/ULMFit/Elmo-style pre-training {tag="experimental"}
> #### Example > #### Example
> >
@ -115,33 +115,6 @@ or `POS` for finding sequences of the same part-of-speech tags.
</Infobox> </Infobox>
### Components and languages via entry points
> #### Example
>
> ```python
> from setuptools import setup
> setup(
> name="custom_extension_package",
> entry_points={
> "spacy_factories": ["your_component = component:ComponentFactory"]
> "spacy_languages": ["xyz = language:XYZLanguage"]
> }
> )
> ```
Using entry points, model packages and extension packages can now define their
own `"spacy_factories"` and `"spacy_languages"`, which will be added to the
built-in factories and languages. If a package in the same environment exposes
spaCy entry points, all of this happens automatically and no further user action
is required.
<Infobox>
**Usage:** [Using entry points](/usage/saving-loading#entry-points)
</Infobox>
### Retokenizer for merging and splitting ### Retokenizer for merging and splitting
> #### Example > #### Example
@ -169,6 +142,33 @@ deprecated.
</Infobox> </Infobox>
### Components and languages via entry points
> #### Example
>
> ```python
> from setuptools import setup
> setup(
> name="custom_extension_package",
> entry_points={
> "spacy_factories": ["your_component = component:ComponentFactory"]
> "spacy_languages": ["xyz = language:XYZLanguage"]
> }
> )
> ```
Using entry points, model packages and extension packages can now define their
own `"spacy_factories"` and `"spacy_languages"`, which will be added to the
built-in factories and languages. If a package in the same environment exposes
spaCy entry points, all of this happens automatically and no further user action
is required.
<Infobox>
**Usage:** [Using entry points](/usage/saving-loading#entry-points)
</Infobox>
### Improved documentation ### Improved documentation
Although it looks pretty much the same, we've rebuilt the entire documentation Although it looks pretty much the same, we've rebuilt the entire documentation
@ -210,6 +210,12 @@ if all of your models are up to date, you can run the
</Infobox> </Infobox>
- Due to difficulties linking our new
[`blis`](https://github.com/explosion/cython-blis) for faster
platform-independent matrix multiplication, this nightly release currently
**doesn't work on Python 2.7 on Windows**. We expect this to be corrected in
the future.
- While the [`Matcher`](/api/matcher) API is fully backwards compatible, its - While the [`Matcher`](/api/matcher) API is fully backwards compatible, its
algorithm has changed to fix a number of bugs and performance issues. This algorithm has changed to fix a number of bugs and performance issues. This
means that the `Matcher` in v2.1.x may produce different results compared to means that the `Matcher` in v2.1.x may produce different results compared to