From 536798f9e3c6ff8c9e7e809bc1d401c7ad81bb03 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Thu, 6 Jul 2023 15:20:13 +0200 Subject: [PATCH] Disallow False for first/last arguments of add_pipe (#12793) * Literal True for first/last options * add test case * update docs * remove old redundant test case * black formatting * use Optional typing in docstrings Co-authored-by: Raphael Mitsch --------- Co-authored-by: Raphael Mitsch --- spacy/errors.py | 1 + spacy/language.py | 20 ++++++++++++-------- spacy/tests/pipeline/test_pipe_methods.py | 18 ++++++++++++++++-- website/docs/api/language.mdx | 7 ++++--- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 238acf6f5..faae74781 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -979,6 +979,7 @@ class Errors(metaclass=ErrorsWithCodes): E4007 = ("Span {var} {value} must be {op} Span {existing_var} " "{existing_value}.") E4008 = ("Span {pos}_char {value} does not correspond to a token {pos}.") + E4009 = ("The '{attr}' parameter should be 'None' or 'True', but found '{value}'.") RENAMED_LANGUAGE_CODES = {"xx": "mul", "is": "isl"} diff --git a/spacy/language.py b/spacy/language.py index 8ea29fdfe..51a4a7f93 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -757,8 +757,8 @@ class Language: *, before: Optional[Union[str, int]] = None, after: Optional[Union[str, int]] = None, - first: Optional[bool] = None, - last: Optional[bool] = None, + first: Optional[Literal[True]] = None, + last: Optional[Literal[True]] = None, source: Optional["Language"] = None, config: Dict[str, Any] = SimpleFrozenDict(), raw_config: Optional[Config] = None, @@ -777,8 +777,8 @@ class Language: component directly before. after (Union[str, int]): Name or index of the component to insert new component directly after. - first (bool): If True, insert component first in the pipeline. - last (bool): If True, insert component last in the pipeline. + first (Optional[Literal[True]]): If True, insert component first in the pipeline. + last (Optional[Literal[True]]): If True, insert component last in the pipeline. source (Language): Optional loaded nlp object to copy the pipeline component from. config (Dict[str, Any]): Config parameters to use for this component. @@ -823,18 +823,22 @@ class Language: self, before: Optional[Union[str, int]] = None, after: Optional[Union[str, int]] = None, - first: Optional[bool] = None, - last: Optional[bool] = None, + first: Optional[Literal[True]] = None, + last: Optional[Literal[True]] = None, ) -> int: """Determine where to insert a pipeline component based on the before/ after/first/last values. before (str): Name or index of the component to insert directly before. after (str): Name or index of component to insert directly after. - first (bool): If True, insert component first in the pipeline. - last (bool): If True, insert component last in the pipeline. + first (Optional[Literal[True]]): If True, insert component first in the pipeline. + last (Optional[Literal[True]]): If True, insert component last in the pipeline. RETURNS (int): The index of the new pipeline component. """ + if first is not None and first is not True: + raise ValueError(Errors.E4009.format(attr="first", value=first)) + if last is not None and last is not True: + raise ValueError(Errors.E4009.format(attr="last", value=last)) all_args = {"before": before, "after": after, "first": first, "last": last} if sum(arg is not None for arg in [before, after, first, last]) >= 2: raise ValueError( diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 39611a742..063e5bf67 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -189,6 +189,22 @@ def test_add_pipe_last(nlp, name1, name2): assert nlp.pipeline[-1][0] == name1 +@pytest.mark.parametrize("name1,name2", [("parser", "lambda_pipe")]) +def test_add_pipe_false(nlp, name1, name2): + Language.component("new_pipe2", func=lambda doc: doc) + nlp.add_pipe("new_pipe2", name=name2) + with pytest.raises( + ValueError, + match="The 'last' parameter should be 'None' or 'True', but found 'False'", + ): + nlp.add_pipe("new_pipe", name=name1, last=False) + with pytest.raises( + ValueError, + match="The 'first' parameter should be 'None' or 'True', but found 'False'", + ): + nlp.add_pipe("new_pipe", name=name1, first=False) + + def test_cant_add_pipe_first_and_last(nlp): with pytest.raises(ValueError): nlp.add_pipe("new_pipe", first=True, last=True) @@ -411,8 +427,6 @@ def test_add_pipe_before_after(): nlp.add_pipe("entity_ruler", before="ner", after=2) with pytest.raises(ValueError): nlp.add_pipe("entity_ruler", before=True) - with pytest.raises(ValueError): - nlp.add_pipe("entity_ruler", first=False) def test_disable_enable_pipes(): diff --git a/website/docs/api/language.mdx b/website/docs/api/language.mdx index 7d89327c4..d26d7b96b 100644 --- a/website/docs/api/language.mdx +++ b/website/docs/api/language.mdx @@ -436,7 +436,8 @@ component factory registered using [`@Language.component`](/api/language#component) or [`@Language.factory`](/api/language#factory). Components should be callables that take a `Doc` object, modify it and return it. Only one of `before`, -`after`, `first` or `last` can be set. Default behavior is `last=True`. +`after`, `first` or `last` can be set. The arguments `first` and `last` can +either be `None` or `True`. Default behavior is `last=True`. @@ -471,8 +472,8 @@ component, adds it to the pipeline and returns it. | _keyword-only_ | | | `before` | Component name or index to insert component directly before. ~~Optional[Union[str, int]]~~ | | `after` | Component name or index to insert component directly after. ~~Optional[Union[str, int]]~~ | -| `first` | Insert component first / not first in the pipeline. ~~Optional[bool]~~ | -| `last` | Insert component last / not last in the pipeline. ~~Optional[bool]~~ | +| `first` | Insert component first in the pipeline if set to `True`. ~~Optional[Literal[True]]~~ | +| `last` | Insert component last in the pipeline if set to `True`. ~~Optional[Literal[True]]~~ | | `config` 3 | Optional config parameters to use for this component. Will be merged with the `default_config` specified by the component factory. ~~Dict[str, Any]~~ | | `source` 3 | Optional source pipeline to copy component from. If a source is provided, the `factory_name` is interpreted as the name of the component in the source pipeline. Make sure that the vocab, vectors and settings of the source pipeline match the target pipeline. ~~Optional[Language]~~ | | `validate` 3 | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. ~~bool~~ |