Merge pull request #6252 from svlandeg/fix/docs

This commit is contained in:
Ines Montani 2020-10-14 16:43:12 +02:00 committed by GitHub
commit cb47f25cda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 12 deletions

View File

@ -195,7 +195,7 @@ class Tagger(TrainablePipe):
validate_examples(examples, "Tagger.update") validate_examples(examples, "Tagger.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples]) tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
for sc in tag_scores: for sc in tag_scores:
@ -227,22 +227,24 @@ class Tagger(TrainablePipe):
DOCS: https://nightly.spacy.io/api/tagger#rehearse DOCS: https://nightly.spacy.io/api/tagger#rehearse
""" """
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "Tagger.rehearse") validate_examples(examples, "Tagger.rehearse")
docs = [eg.predicted for eg in examples] docs = [eg.predicted for eg in examples]
if self._rehearsal_model is None: if self._rehearsal_model is None:
return return losses
if not any(len(doc) for doc in docs): if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
guesses, backprop = self.model.begin_update(docs) guesses, backprop = self.model.begin_update(docs)
target = self._rehearsal_model(examples) target = self._rehearsal_model(examples)
gradient = guesses - target gradient = guesses - target
backprop(gradient) backprop(gradient)
self.finish_update(sgd) self.finish_update(sgd)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum() losses[self.name] += (gradient**2).sum()
return losses
def get_loss(self, examples, scores): def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of documents and """Find the loss and gradient of loss for the batch of documents and

View File

@ -116,7 +116,7 @@ cdef class TrainablePipe(Pipe):
validate_examples(examples, "TrainablePipe.update") validate_examples(examples, "TrainablePipe.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples]) scores, bp_scores = self.model.begin_update([eg.predicted for eg in examples])
loss, d_scores = self.get_loss(examples, scores) loss, d_scores = self.get_loss(examples, scores)

View File

@ -503,7 +503,7 @@ overview of the `TrainablePipe` methods used by
</Infobox> </Infobox>
### Example: Entity elation extraction component {#component-rel} ### Example: Entity relation extraction component {#component-rel}
This section outlines an example use-case of implementing a **novel relation This section outlines an example use-case of implementing a **novel relation
extraction component** from scratch. We'll implement a binary relation extraction component** from scratch. We'll implement a binary relation
@ -618,7 +618,7 @@ we can define our relation model in a config file as such:
# ... # ...
[model.get_candidates] [model.get_candidates]
@misc = "rel_cand_generator.v2" @misc = "rel_cand_generator.v1"
max_length = 20 max_length = 20
[model.create_candidate_tensor] [model.create_candidate_tensor]
@ -687,8 +687,8 @@ Before the model can be used, it needs to be
[initialized](/usage/training#initialization). This function receives a callback [initialized](/usage/training#initialization). This function receives a callback
to access the full **training data set**, or a representative sample. This data to access the full **training data set**, or a representative sample. This data
set can be used to deduce all **relevant labels**. Alternatively, a list of set can be used to deduce all **relevant labels**. Alternatively, a list of
labels can be provided to `initialize`, or you can call the labels can be provided to `initialize`, or you can call
`RelationExtractoradd_label` directly. The number of labels defines the output `RelationExtractor.add_label` directly. The number of labels defines the output
dimensionality of the network, and will be used to do dimensionality of the network, and will be used to do
[shape inference](https://thinc.ai/docs/usage-models#validation) throughout the [shape inference](https://thinc.ai/docs/usage-models#validation) throughout the
layers of the neural network. This is triggered by calling layers of the neural network. This is triggered by calling
@ -729,7 +729,7 @@ and its internal model can be trained and used to make predictions.
During training, the function [`update`](/api/pipe#update) is invoked which During training, the function [`update`](/api/pipe#update) is invoked which
delegates to delegates to
[`Model.begin_update`](https://thinc.ai/docs/api-model#begin_update) and a [`Model.begin_update`](https://thinc.ai/docs/api-model#begin_update) and a
[`get_loss`](/api/pipe#get_loss) function that **calculate the loss** for a [`get_loss`](/api/pipe#get_loss) function that **calculates the loss** for a
batch of examples, as well as the **gradient** of loss that will be used to batch of examples, as well as the **gradient** of loss that will be used to
update the weights of the model layers. Thinc provides several update the weights of the model layers. Thinc provides several
[loss functions](https://thinc.ai/docs/api-loss) that can be used for the [loss functions](https://thinc.ai/docs/api-loss) that can be used for the