diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md index 9e8e87239..836bdac67 100644 --- a/website/docs/usage/processing-pipelines.md +++ b/website/docs/usage/processing-pipelines.md @@ -1547,24 +1547,33 @@ to `Doc.user_span_hooks` and `Doc.user_token_hooks`. | Name | Customizes | | ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `user_hooks` | [`Doc.vector`](/api/doc#vector), [`Doc.has_vector`](/api/doc#has_vector), [`Doc.vector_norm`](/api/doc#vector_norm), [`Doc.sents`](/api/doc#sents) | +| `user_hooks` | [`Doc.similarity`](/api/doc#similarity), [`Doc.vector`](/api/doc#vector), [`Doc.has_vector`](/api/doc#has_vector), [`Doc.vector_norm`](/api/doc#vector_norm), [`Doc.sents`](/api/doc#sents) | | `user_token_hooks` | [`Token.similarity`](/api/token#similarity), [`Token.vector`](/api/token#vector), [`Token.has_vector`](/api/token#has_vector), [`Token.vector_norm`](/api/token#vector_norm), [`Token.conjuncts`](/api/token#conjuncts) | | `user_span_hooks` | [`Span.similarity`](/api/span#similarity), [`Span.vector`](/api/span#vector), [`Span.has_vector`](/api/span#has_vector), [`Span.vector_norm`](/api/span#vector_norm), [`Span.root`](/api/span#root) | ```python ### Add custom similarity hooks +from spacy.language import Language + + class SimilarityModel: - def __init__(self, model): - self._model = model + def __init__(self, name: str, index: int): + self.name = name + self.index = index def __call__(self, doc): doc.user_hooks["similarity"] = self.similarity doc.user_span_hooks["similarity"] = self.similarity doc.user_token_hooks["similarity"] = self.similarity + return doc def similarity(self, obj1, obj2): - y = self._model([obj1.vector, obj2.vector]) - return float(y[0]) + return obj1.vector[self.index] + obj2.vector[self.index] + + +@Language.factory("similarity_component", default_config={"index": 0}) +def create_similarity_component(nlp, name, index: int): + return SimilarityModel(name, index) ``` ## Developing plugins and wrappers {#plugins}