1
1
mirror of https://github.com/explosion/spaCy.git synced 2025-01-18 13:34:13 +03:00
spaCy/spacy/tests/util.py

140 lines
4.3 KiB
Python
Raw Normal View History

# coding: utf-8
from __future__ import unicode_literals
2017-01-12 18:49:57 +03:00
import numpy
import tempfile
import shutil
import contextlib
import srsly
from pathlib import Path
from spacy import Errors
💫 Refactor test suite (#2568) ## Description Related issues: #2379 (should be fixed by separating model tests) * **total execution time down from > 300 seconds to under 60 seconds** 🎉 * removed all model-specific tests that could only really be run manually anyway – those will now live in a separate test suite in the [`spacy-models`](https://github.com/explosion/spacy-models) repository and are already integrated into our new model training infrastructure * changed all relative imports to absolute imports to prepare for moving the test suite from `/spacy/tests` to `/tests` (it'll now always test against the installed version) * merged old regression tests into collections, e.g. `test_issue1001-1500.py` (about 90% of the regression tests are very short anyways) * tidied up and rewrote existing tests wherever possible ### Todo - [ ] move tests to `/tests` and adjust CI commands accordingly - [x] move model test suite from internal repo to `spacy-models` - [x] ~~investigate why `pipeline/test_textcat.py` is flakey~~ - [x] review old regression tests (leftover files) and see if they can be merged, simplified or deleted - [ ] update documentation on how to run tests ### Types of change enhancement, tests ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information.
2018-07-25 00:38:44 +03:00
from spacy.tokens import Doc, Span
from spacy.attrs import POS, TAG, HEAD, DEP, LEMMA
💫 Refactor test suite (#2568) ## Description Related issues: #2379 (should be fixed by separating model tests) * **total execution time down from > 300 seconds to under 60 seconds** 🎉 * removed all model-specific tests that could only really be run manually anyway – those will now live in a separate test suite in the [`spacy-models`](https://github.com/explosion/spacy-models) repository and are already integrated into our new model training infrastructure * changed all relative imports to absolute imports to prepare for moving the test suite from `/spacy/tests` to `/tests` (it'll now always test against the installed version) * merged old regression tests into collections, e.g. `test_issue1001-1500.py` (about 90% of the regression tests are very short anyways) * tidied up and rewrote existing tests wherever possible ### Todo - [ ] move tests to `/tests` and adjust CI commands accordingly - [x] move model test suite from internal repo to `spacy-models` - [x] ~~investigate why `pipeline/test_textcat.py` is flakey~~ - [x] review old regression tests (leftover files) and see if they can be merged, simplified or deleted - [ ] update documentation on how to run tests ### Types of change enhancement, tests ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [ ] My changes don't require a change to the documentation, or if they do, I've added all required information.
2018-07-25 00:38:44 +03:00
from spacy.compat import path2str
@contextlib.contextmanager
def make_tempfile(mode="r"):
    """Context manager yielding an open temporary file.

    mode: the mode string passed through to tempfile.TemporaryFile.
    YIELDS: the open temporary file object.

    The file is closed when the block exits, including when the body of the
    `with` statement raises.
    """
    f = tempfile.TemporaryFile(mode=mode)
    try:
        yield f
    finally:
        # Close even if the caller's block raised, so the handle never leaks.
        f.close()
@contextlib.contextmanager
def make_tempdir():
    """Context manager yielding a freshly created temporary directory.

    YIELDS (Path): path of the directory created with tempfile.mkdtemp.

    The directory tree is removed when the block exits, including when the
    body of the `with` statement raises.
    """
    d = Path(tempfile.mkdtemp())
    try:
        yield d
    finally:
        # Remove the directory even if the caller's block raised, so failed
        # tests do not leave stray temp directories behind.
        shutil.rmtree(path2str(d))
2020-03-25 14:28:12 +03:00
def get_doc(
    vocab, words=[], pos=None, heads=None, deps=None, tags=None, ents=None, lemmas=None
):
    """Create Doc object from given vocab, words and annotations.

    vocab: vocab used to build the Doc; annotation strings are added to
        vocab.strings before the Doc is created.
    words: token texts passed to Doc(vocab, words=words).
    pos / deps / tags / lemmas: optional per-token string annotations.
    heads: optional per-token values written as-is into the HEAD column.
        If deps is given without heads, heads defaults to all zeros.
    ents: optional iterable of (start, end, label) triples turned into Spans.
    RETURNS: the annotated Doc.
    RAISES: ValueError (Errors.E189) if a supplied annotation does not have
        the same length as words.
    """
    if deps and not heads:
        heads = [0] * len(deps)
    annotations = [pos, heads, deps, lemmas, tags]
    possible_headings = [POS, HEAD, DEP, LEMMA, TAG]
    headings = []
    values = []
    for heading, annot in zip(possible_headings, annotations):
        if annot is None:
            continue
        if len(annot) != len(words):
            raise ValueError(Errors.E189)
        headings.append(heading)
        # Heads are numeric, so they are not registered in the string store.
        if annot is not heads:
            values.extend(annot)
    # Strings are only added once every annotation has passed validation.
    for value in values:
        vocab.strings.add(value)
    doc = Doc(vocab, words=words)
    # if there are any other annotations, set them
    if headings:
        attrs = doc.to_array(headings)
        column = 0
        for annot in annotations:
            if not annot:
                continue
            writing_heads = annot is heads
            for i in range(len(words)):
                if writing_heads:
                    cell = heads[i]
                else:
                    cell = doc.vocab.strings[annot[i]]
                # With a single heading the exported array is 1-dimensional.
                if attrs.ndim == 1:
                    attrs[i] = cell
                else:
                    attrs[i, column] = cell
            column += 1
        doc.from_array(headings, attrs)
    # finally, set the entities
    if ents:
        doc.ents = [
            Span(doc, start, end, label=doc.vocab.strings[label])
            for start, end, label in ents
        ]
    return doc
def apply_transition_sequence(parser, doc, sequence):
    """Perform a series of pre-specified transitions, to put the parser in a
    desired state.

    parser: object providing add_label(label) and step_through(doc).
    doc: the document handed to parser.step_through.
    sequence: transition names. It is iterated twice (once to register
        labels, once to apply the transitions), so it must be re-iterable,
        e.g. a list. A name containing "-" is read as "<move>-<label>" and
        its label is registered with the parser before stepping.
    """
    for action_name in sequence:
        if "-" in action_name:
            # Split on the first hyphen only, so a label that itself contains
            # a hyphen is kept whole instead of raising on unpacking.
            label = action_name.split("-", 1)[1]
            parser.add_label(label)
    with parser.step_through(doc) as stepwise:
        for transition in sequence:
            stepwise.transition(transition)
2017-01-12 18:49:57 +03:00
def add_vecs_to_vocab(vocab, vectors):
    """Reset the vocab's vectors and fill them from a list of vector tuples.

    vocab: object providing reset_vectors(width=...) and set_vector(...).
    vectors: list of (text, vector) pairs, e.g. [("text", [1, 2, 3])]. The
        width is taken from the first vector, so all vectors need to have
        the same length.
    RETURNS: the same vocab object that was passed in.
    """
    width = len(vectors[0][1])
    vocab.reset_vectors(width=width)
    for entry in vectors:
        text, values = entry
        vocab.set_vector(text, vector=values)
    return vocab
2017-01-12 18:49:57 +03:00
def get_cosine(vec1, vec2):
    """Return the cosine of the angle between two vectors.

    Computed as the dot product of vec1 and vec2 divided by the product of
    their norms (numpy.linalg.norm).
    """
    dot_product = numpy.dot(vec1, vec2)
    norm_product = numpy.linalg.norm(vec1) * numpy.linalg.norm(vec2)
    return dot_product / norm_product
def assert_docs_equal(doc1, doc2):
    """Assert that two Doc objects carry the same annotations.

    Compares, in order, the per-token orth, pos, tag, head index, dep,
    is_sent_start, ent_type and ent_iob values, then the start, end, label
    and kb_id of the entities paired up from doc1.ents and doc2.ents.
    Raises AssertionError on the first mismatch.
    """
    token_views = (
        lambda token: token.orth,
        lambda token: token.pos,
        lambda token: token.tag,
        lambda token: token.head.i,
        lambda token: token.dep,
        lambda token: token.is_sent_start,
        lambda token: token.ent_type,
        lambda token: token.ent_iob,
    )
    for view in token_views:
        left = [view(token) for token in doc1]
        right = [view(token) for token in doc2]
        assert left == right
    for left_ent, right_ent in zip(doc1.ents, doc2.ents):
        for field in ("start", "end", "label", "kb_id"):
            assert getattr(left_ent, field) == getattr(right_ent, field)
def assert_packed_msg_equal(b1, b2):
    """Assert that two packed msgpack messages are equal.

    Both byte strings are unpacked with srsly.msgpack_loads; the resulting
    mappings must have the same keys and equal values under each key.
    """
    first = srsly.msgpack_loads(b1)
    second = srsly.msgpack_loads(b2)
    first_keys = sorted(first.keys())
    second_keys = sorted(second.keys())
    assert first_keys == second_keys
    # The key lists are identical here, so compare values key by key.
    for key in first_keys:
        assert first[key] == second[key]