Fix docs example [ci skip]

This commit is contained in:
Ines Montani 2020-10-09 16:03:57 +02:00
parent 9fb3244672
commit 97ff090e49

View File

@ -1403,9 +1403,9 @@ especially useful if you want to pass in a string instead of calling
This example shows the implementation of a pipeline component that fetches This example shows the implementation of a pipeline component that fetches
country meta data via the [REST Countries API](https://restcountries.eu), sets country meta data via the [REST Countries API](https://restcountries.eu), sets
entity annotations for countries, merges entities into one token and sets custom entity annotations for countries and sets custom attributes on the `Doc` and
attributes on the `Doc`, `Span` and `Token` – for example, the capital, `Span` – for example, the capital, latitude/longitude coordinates and even the
latitude/longitude coordinates and even the country flag. country flag.
```python ```python
### {executable="true"} ### {executable="true"}
@ -1427,54 +1427,46 @@ class RESTCountriesComponent:
# Set up the PhraseMatcher with Doc patterns for each country name # Set up the PhraseMatcher with Doc patterns for each country name
self.matcher = PhraseMatcher(nlp.vocab) self.matcher = PhraseMatcher(nlp.vocab)
self.matcher.add("COUNTRIES", [nlp.make_doc(c) for c in self.countries.keys()]) self.matcher.add("COUNTRIES", [nlp.make_doc(c) for c in self.countries.keys()])
# Register attribute on the Token. We'll be overwriting this based on # Register attributes on the Span. We'll be overwriting this based on
# the matches, so we're only setting a default value, not a getter. # the matches, so we're only setting a default value, not a getter.
Token.set_extension("is_country", default=False) Span.set_extension("is_country", default=None)
Token.set_extension("country_capital", default=False) Span.set_extension("country_capital", default=None)
Token.set_extension("country_latlng", default=False) Span.set_extension("country_latlng", default=None)
Token.set_extension("country_flag", default=False) Span.set_extension("country_flag", default=None)
# Register attributes on Doc and Span via a getter that checks if one of # Register attribute on Doc via a getter that checks if the Doc
# the contained tokens is set to is_country == True. # contains a country entity
Doc.set_extension("has_country", getter=self.has_country) Doc.set_extension("has_country", getter=self.has_country)
Span.set_extension("has_country", getter=self.has_country)
def __call__(self, doc): def __call__(self, doc):
spans = [] # keep the spans for later so we can merge them afterwards spans = [] # keep the spans for later so we can merge them afterwards
for _, start, end in self.matcher(doc): for _, start, end in self.matcher(doc):
# Generate Span representing the entity & set label # Generate Span representing the entity & set label
entity = Span(doc, start, end, label=self.label) entity = Span(doc, start, end, label=self.label)
# Set custom attributes on entity. Can be extended with other data
# returned by the API, like currencies, country code, calling code etc.
entity._.set("is_country", True)
entity._.set("country_capital", self.countries[entity.text]["capital"])
entity._.set("country_latlng", self.countries[entity.text]["latlng"])
entity._.set("country_flag", self.countries[entity.text]["flag"])
spans.append(entity) spans.append(entity)
# Set custom attribute on each token of the entity
# Can be extended with other data returned by the API, like
# currencies, country code, flag, calling code etc.
for token in entity:
token._.set("is_country", True)
token._.set("country_capital", self.countries[entity.text]["capital"])
token._.set("country_latlng", self.countries[entity.text]["latlng"])
token._.set("country_flag", self.countries[entity.text]["flag"])
# Iterate over all spans and merge them into one token
with doc.retokenize() as retokenizer:
for span in spans:
retokenizer.merge(span)
# Overwrite doc.ents and add entity – be careful not to replace! # Overwrite doc.ents and add entity – be careful not to replace!
doc.ents = list(doc.ents) + spans doc.ents = list(doc.ents) + spans
return doc # don't forget to return the Doc! return doc # don't forget to return the Doc!
def has_country(self, tokens): def has_country(self, doc):
"""Getter for Doc and Span attributes. Since the getter is only called """Getter for Doc attributes. Since the getter is only called
when we access the attribute, we can refer to the Token's 'is_country' when we access the attribute, we can refer to the Span's 'is_country'
attribute here, which is already set in the processing step.""" attribute here, which is already set in the processing step."""
return any([t._.get("is_country") for t in tokens]) return any([entity._.get("is_country") for entity in doc.ents])
nlp = English() nlp = English()
nlp.add_pipe("rest_countries", config={"label": "GPE"}) nlp.add_pipe("rest_countries", config={"label": "GPE"})
doc = nlp("Some text about Colombia and the Czech Republic") doc = nlp("Some text about Colombia and the Czech Republic")
print("Pipeline", nlp.pipe_names) # pipeline contains component name print("Pipeline", nlp.pipe_names) # pipeline contains component name
print("Doc has countries", doc._.has_country) # Doc contains countries print("Doc has countries", doc._.has_country) # Doc contains countries
for token in doc: for ent in doc.ents:
if token._.is_country: if ent._.is_country:
print(token.text, token._.country_capital, token._.country_latlng, token._.country_flag) print(ent.text, ent.label_, ent._.country_capital, ent._.country_latlng, ent._.country_flag)
print("Entities", [(e.text, e.label_) for e in doc.ents])
``` ```
In this case, all data can be fetched on initialization in one request. However, In this case, all data can be fetched on initialization in one request. However,