mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Fix pickle for ngram suggester (#12486)
This commit is contained in:
parent
b228875600
commit
57ee1212de
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast, Union
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
|
||||
from thinc.api import Optimizer
|
||||
from thinc.types import Ragged, Ints2d, Floats2d
|
||||
|
@ -82,13 +83,9 @@ class Suggester(Protocol):
|
|||
...
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
position."""
|
||||
|
||||
def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
|
||||
def ngram_suggester(
|
||||
docs: Iterable[Doc], sizes: List[int], *, ops: Optional[Ops] = None
|
||||
) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
|
@ -114,7 +111,14 @@ def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
|||
assert output.dataXd.ndim == 2
|
||||
return output
|
||||
|
||||
return ngram_suggester
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
array of integers. The array has two columns, indicating the start and end
|
||||
position."""
|
||||
|
||||
return partial(ngram_suggester, sizes=sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_range_suggester.v1")
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
import numpy
|
||||
from numpy.testing import assert_array_equal, assert_almost_equal
|
||||
from thinc.api import get_current_ops, Ragged
|
||||
from thinc.api import get_current_ops, NumpyOps, Ragged
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
|
@ -577,3 +577,21 @@ def test_set_candidates(name):
|
|||
assert len(docs[0].spans["candidates"]) == 9
|
||||
assert docs[0].spans["candidates"][0].text == "Just"
|
||||
assert docs[0].spans["candidates"][4].text == "Just a"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", SPANCAT_COMPONENTS)
|
||||
@pytest.mark.parametrize("n_process", [1, 2])
|
||||
def test_spancat_multiprocessing(name, n_process):
|
||||
if isinstance(get_current_ops, NumpyOps) or n_process < 2:
|
||||
nlp = Language()
|
||||
spancat = nlp.add_pipe(name, config={"spans_key": SPAN_KEY})
|
||||
train_examples = make_examples(nlp)
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
texts = [
|
||||
"Just a sentence.",
|
||||
"I like London and Berlin",
|
||||
"I like Berlin",
|
||||
"I eat ham.",
|
||||
]
|
||||
docs = list(nlp.pipe(texts, n_process=n_process))
|
||||
assert len(docs) == len(texts)
|
||||
|
|
Loading…
Reference in New Issue
Block a user