mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Fix pickle for ngram suggester (#12486)
This commit is contained in:
parent
b228875600
commit
57ee1212de
|
@ -1,5 +1,6 @@
|
||||||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
|
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||||
from thinc.api import Optimizer
|
from thinc.api import Optimizer
|
||||||
from thinc.types import Ragged, Ints2d, Floats2d
|
from thinc.types import Ragged, Ints2d, Floats2d
|
||||||
|
@ -82,13 +83,9 @@ class Suggester(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.ngram_suggester.v1")
|
def ngram_suggester(
|
||||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
|
||||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
) -> Ragged:
|
||||||
array of integers. The array has two columns, indicating the start and end
|
|
||||||
position."""
|
|
||||||
|
|
||||||
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
|
|
||||||
if ops is None:
|
if ops is None:
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
spans = []
|
spans = []
|
||||||
|
@ -114,7 +111,14 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||||
assert output.dataXd.ndim == 2
|
assert output.dataXd.ndim == 2
|
||||||
return output
|
return output
|
||||||
|
|
||||||
return ngram_suggester
|
|
||||||
|
@registry.misc("spacy.ngram_suggester.v1")
|
||||||
|
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||||
|
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||||
|
array of integers. The array has two columns, indicating the start and end
|
||||||
|
position."""
|
||||||
|
|
||||||
|
return partial(ngram_suggester, sizes=sizes)
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.ngram_range_suggester.v1")
|
@registry.misc("spacy.ngram_range_suggester.v1")
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
import numpy
|
import numpy
|
||||||
from numpy.testing import assert_array_equal, assert_almost_equal
|
from numpy.testing import assert_array_equal, assert_almost_equal
|
||||||
from thinc.api import get_current_ops, Ragged
|
from thinc.api import get_current_ops, NumpyOps, Ragged
|
||||||
|
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
@ -577,3 +577,21 @@ def test_set_candidates(name):
|
||||||
assert len(docs[0].spans["candidates"]) == 9
|
assert len(docs[0].spans["candidates"]) == 9
|
||||||
assert docs[0].spans["candidates"][0].text == "Just"
|
assert docs[0].spans["candidates"][0].text == "Just"
|
||||||
assert docs[0].spans["candidates"][4].text == "Just a"
|
assert docs[0].spans["candidates"][4].text == "Just a"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||||
|
@pytest.mark.parametrize("n_process", [1, 2])
|
||||||
|
def test_spancat_multiprocessing(name, n_process):
|
||||||
|
if isinstance(get_current_ops, NumpyOps) or n_process < 2:
|
||||||
|
nlp = Language()
|
||||||
|
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||||
|
train_examples = make_examples(nlp)
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
texts = [
|
||||||
|
"Just a sentence.",
|
||||||
|
"I like London and Berlin",
|
||||||
|
"I like Berlin",
|
||||||
|
"I eat ham.",
|
||||||
|
]
|
||||||
|
docs = list(nlp.pipe(texts, n_process=n_process))
|
||||||
|
assert len(docs) == len(texts)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user