mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 02:36:32 +03:00
Merge remote-tracking branch 'upstream/develop' into feature/docs-layers
# Conflicts: # website/docs/usage/layers-architectures.md
This commit is contained in:
commit
422df9c2e2
|
@ -297,9 +297,7 @@ def ensure_pathy(path):
|
||||||
return Pathy(path)
|
return Pathy(path)
|
||||||
|
|
||||||
|
|
||||||
def git_sparse_checkout(
|
def git_sparse_checkout(repo: str, subpath: str, dest: Path, *, branch: str = "master"):
|
||||||
repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None
|
|
||||||
):
|
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
msg.fail("Destination of checkout must not exist", exits=1)
|
msg.fail("Destination of checkout must not exist", exits=1)
|
||||||
if not dest.parent.exists():
|
if not dest.parent.exists():
|
||||||
|
@ -323,21 +321,30 @@ def git_sparse_checkout(
|
||||||
# This is the "clone, but don't download anything" part.
|
# This is the "clone, but don't download anything" part.
|
||||||
cmd = (
|
cmd = (
|
||||||
f"git clone {repo} {tmp_dir} --no-checkout --depth 1 "
|
f"git clone {repo} {tmp_dir} --no-checkout --depth 1 "
|
||||||
"--filter=blob:none" # <-- The key bit
|
f"--filter=blob:none " # <-- The key bit
|
||||||
|
f"-b {branch}"
|
||||||
)
|
)
|
||||||
if branch is not None:
|
|
||||||
cmd = f"{cmd} -b {branch}"
|
|
||||||
run_command(cmd, capture=True)
|
run_command(cmd, capture=True)
|
||||||
# Now we need to find the missing filenames for the subpath we want.
|
# Now we need to find the missing filenames for the subpath we want.
|
||||||
# Looking for this 'rev-list' command in the git --help? Hah.
|
# Looking for this 'rev-list' command in the git --help? Hah.
|
||||||
cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
|
cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
|
||||||
ret = run_command(cmd, capture=True)
|
ret = run_command(cmd, capture=True)
|
||||||
missings = "\n".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
|
repo = _from_http_to_git(repo)
|
||||||
# Now pass those missings into another bit of git internals
|
# Now pass those missings into another bit of git internals
|
||||||
run_command(
|
missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
|
||||||
f"git -C {tmp_dir} fetch-pack --stdin {repo}", capture=True, stdin=missings
|
cmd = f"git -C {tmp_dir} fetch-pack {repo} {missings}"
|
||||||
)
|
run_command(cmd, capture=True)
|
||||||
# And finally, we can checkout our subpath
|
# And finally, we can checkout our subpath
|
||||||
run_command(f"git -C {tmp_dir} checkout {branch} {subpath}")
|
cmd = f"git -C {tmp_dir} checkout {branch} {subpath}"
|
||||||
|
run_command(cmd)
|
||||||
# We need Path(name) to make sure we also support subdirectories
|
# We need Path(name) to make sure we also support subdirectories
|
||||||
shutil.move(str(tmp_dir / Path(subpath)), str(dest))
|
shutil.move(str(tmp_dir / Path(subpath)), str(dest))
|
||||||
|
|
||||||
|
|
||||||
|
def _from_http_to_git(repo):
|
||||||
|
if repo.startswith("http://"):
|
||||||
|
repo = repo.replace(r"http://", r"https://")
|
||||||
|
if repo.startswith(r"https://"):
|
||||||
|
repo = repo.replace("https://", "git@").replace("/", ":", 1)
|
||||||
|
repo = f"{repo}.git"
|
||||||
|
return repo
|
||||||
|
|
|
@ -43,7 +43,7 @@ def project_clone(name: str, dest: Path, *, repo: str = about.__projects__) -> N
|
||||||
git_sparse_checkout(repo, name, dest)
|
git_sparse_checkout(repo, name, dest)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
err = f"Could not clone '{name}' from repo '{repo_name}'"
|
err = f"Could not clone '{name}' from repo '{repo_name}'"
|
||||||
msg.fail(err)
|
msg.fail(err, exits=1)
|
||||||
msg.good(f"Cloned '{name}' from {repo_name}", project_dir)
|
msg.good(f"Cloned '{name}' from {repo_name}", project_dir)
|
||||||
if not (project_dir / PROJECT_FILE).exists():
|
if not (project_dir / PROJECT_FILE).exists():
|
||||||
msg.warn(f"No {PROJECT_FILE} found in directory")
|
msg.warn(f"No {PROJECT_FILE} found in directory")
|
||||||
|
@ -78,6 +78,7 @@ def check_clone(name: str, dest: Path, repo: str) -> None:
|
||||||
if not dest.parent.exists():
|
if not dest.parent.exists():
|
||||||
# We're not creating parents, parent dir should exist
|
# We're not creating parents, parent dir should exist
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"Can't clone project, parent directory doesn't exist: {dest.parent}",
|
f"Can't clone project, parent directory doesn't exist: {dest.parent}. "
|
||||||
|
f"Create the necessary folder(s) first before continuing.",
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,6 +7,7 @@ _concat_icons = CONCAT_ICONS.replace("\u00B0", "")
|
||||||
|
|
||||||
_currency = r"\$¢£€¥฿"
|
_currency = r"\$¢£€¥฿"
|
||||||
_quotes = CONCAT_QUOTES.replace("'", "")
|
_quotes = CONCAT_QUOTES.replace("'", "")
|
||||||
|
_units = UNITS.replace("%", "")
|
||||||
|
|
||||||
_prefixes = (
|
_prefixes = (
|
||||||
LIST_PUNCT
|
LIST_PUNCT
|
||||||
|
@ -26,7 +27,7 @@ _suffixes = (
|
||||||
r"(?<=[0-9])\+",
|
r"(?<=[0-9])\+",
|
||||||
r"(?<=°[FfCcKk])\.",
|
r"(?<=°[FfCcKk])\.",
|
||||||
r"(?<=[0-9])(?:[{c}])".format(c=_currency),
|
r"(?<=[0-9])(?:[{c}])".format(c=_currency),
|
||||||
r"(?<=[0-9])(?:{u})".format(u=UNITS),
|
r"(?<=[0-9])(?:{u})".format(u=_units),
|
||||||
r"(?<=[{al}{e}{q}(?:{c})])\.".format(
|
r"(?<=[{al}{e}{q}(?:{c})])\.".format(
|
||||||
al=ALPHA_LOWER, e=r"%²\-\+", q=CONCAT_QUOTES, c=_currency
|
al=ALPHA_LOWER, e=r"%²\-\+", q=CONCAT_QUOTES, c=_currency
|
||||||
),
|
),
|
||||||
|
|
|
@ -42,6 +42,7 @@ cdef cppclass StateC:
|
||||||
RingBufferC _hist
|
RingBufferC _hist
|
||||||
int length
|
int length
|
||||||
int offset
|
int offset
|
||||||
|
int n_pushes
|
||||||
int _s_i
|
int _s_i
|
||||||
int _b_i
|
int _b_i
|
||||||
int _e_i
|
int _e_i
|
||||||
|
@ -49,6 +50,7 @@ cdef cppclass StateC:
|
||||||
|
|
||||||
__init__(const TokenC* sent, int length) nogil:
|
__init__(const TokenC* sent, int length) nogil:
|
||||||
cdef int PADDING = 5
|
cdef int PADDING = 5
|
||||||
|
this.n_pushes = 0
|
||||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||||
|
@ -335,6 +337,7 @@ cdef cppclass StateC:
|
||||||
this.set_break(this.B_(0).l_edge)
|
this.set_break(this.B_(0).l_edge)
|
||||||
if this._b_i > this._break:
|
if this._b_i > this._break:
|
||||||
this._break = -1
|
this._break = -1
|
||||||
|
this.n_pushes += 1
|
||||||
|
|
||||||
void pop() nogil:
|
void pop() nogil:
|
||||||
if this._s_i >= 1:
|
if this._s_i >= 1:
|
||||||
|
@ -351,6 +354,7 @@ cdef cppclass StateC:
|
||||||
this._buffer[this._b_i] = this.S(0)
|
this._buffer[this._b_i] = this.S(0)
|
||||||
this._s_i -= 1
|
this._s_i -= 1
|
||||||
this.shifted[this.B(0)] = True
|
this.shifted[this.B(0)] = True
|
||||||
|
this.n_pushes -= 1
|
||||||
|
|
||||||
void add_arc(int head, int child, attr_t label) nogil:
|
void add_arc(int head, int child, attr_t label) nogil:
|
||||||
if this.has_head(child):
|
if this.has_head(child):
|
||||||
|
@ -431,6 +435,7 @@ cdef cppclass StateC:
|
||||||
this._break = src._break
|
this._break = src._break
|
||||||
this.offset = src.offset
|
this.offset = src.offset
|
||||||
this._empty_token = src._empty_token
|
this._empty_token = src._empty_token
|
||||||
|
this.n_pushes = src.n_pushes
|
||||||
|
|
||||||
void fast_forward() nogil:
|
void fast_forward() nogil:
|
||||||
# space token attachement policy:
|
# space token attachement policy:
|
||||||
|
|
|
@ -36,6 +36,10 @@ cdef class StateClass:
|
||||||
hist[i] = self.c.get_hist(i+1)
|
hist[i] = self.c.get_hist(i+1)
|
||||||
return hist
|
return hist
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_pushes(self):
|
||||||
|
return self.c.n_pushes
|
||||||
|
|
||||||
def is_final(self):
|
def is_final(self):
|
||||||
return self.c.is_final()
|
return self.c.is_final()
|
||||||
|
|
||||||
|
|
|
@ -289,7 +289,14 @@ class Tagger(Pipe):
|
||||||
err = Errors.E1006.format(name="Tagger")
|
err = Errors.E1006.format(name="Tagger")
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
self.set_output(len(self.labels))
|
self.set_output(len(self.labels))
|
||||||
self.model.initialize(X=doc_sample)
|
if doc_sample:
|
||||||
|
label_sample = [
|
||||||
|
self.model.ops.alloc2f(len(doc), len(self.labels))
|
||||||
|
for doc in doc_sample
|
||||||
|
]
|
||||||
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
|
else:
|
||||||
|
self.model.initialize()
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
return sgd
|
return sgd
|
||||||
|
|
|
@ -279,14 +279,14 @@ cdef class Parser(Pipe):
|
||||||
# Chop sequences into lengths of this many transitions, to make the
|
# Chop sequences into lengths of this many transitions, to make the
|
||||||
# batch uniform length.
|
# batch uniform length.
|
||||||
# We used to randomize this, but it's not clear that actually helps?
|
# We used to randomize this, but it's not clear that actually helps?
|
||||||
cut_size = self.cfg["update_with_oracle_cut_size"]
|
max_pushes = self.cfg["update_with_oracle_cut_size"]
|
||||||
states, golds, max_steps = self._init_gold_batch(
|
states, golds, _ = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=cut_size
|
max_length=max_pushes
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
max_steps = max([len(eg.x) for eg in examples])
|
max_pushes = max([len(eg.x) for eg in examples])
|
||||||
if not states:
|
if not states:
|
||||||
return losses
|
return losses
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
|
@ -302,7 +302,8 @@ cdef class Parser(Pipe):
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
self.transition_states(states, scores)
|
||||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
||||||
|
if s.n_pushes < max_pushes and not s.is_final()]
|
||||||
|
|
||||||
backprop_tok2vec(golds)
|
backprop_tok2vec(golds)
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
|
|
|
@ -84,9 +84,8 @@ def test_overfitting_IO():
|
||||||
# Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
|
# Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
textcat = nlp.add_pipe("textcat")
|
|
||||||
# Set exclusive labels
|
# Set exclusive labels
|
||||||
textcat.model.attrs["multi_label"] = False
|
textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}})
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
@ -103,9 +102,8 @@ def test_overfitting_IO():
|
||||||
test_text = "I am happy."
|
test_text = "I am happy."
|
||||||
doc = nlp(test_text)
|
doc = nlp(test_text)
|
||||||
cats = doc.cats
|
cats = doc.cats
|
||||||
# note that by default, exclusive_classes = false so we need a bigger error margin
|
assert cats["POSITIVE"] > 0.9
|
||||||
assert cats["POSITIVE"] > 0.8
|
assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
||||||
assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.1)
|
|
||||||
|
|
||||||
# Also test the results are still the same after IO
|
# Also test the results are still the same after IO
|
||||||
with make_tempdir() as tmp_dir:
|
with make_tempdir() as tmp_dir:
|
||||||
|
@ -113,8 +111,8 @@ def test_overfitting_IO():
|
||||||
nlp2 = util.load_model_from_path(tmp_dir)
|
nlp2 = util.load_model_from_path(tmp_dir)
|
||||||
doc2 = nlp2(test_text)
|
doc2 = nlp2(test_text)
|
||||||
cats2 = doc2.cats
|
cats2 = doc2.cats
|
||||||
assert cats2["POSITIVE"] > 0.8
|
assert cats2["POSITIVE"] > 0.9
|
||||||
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)
|
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
||||||
|
|
||||||
# Test scoring
|
# Test scoring
|
||||||
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
|
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
|
||||||
|
|
|
@ -3,8 +3,9 @@ title: Layers and Model Architectures
|
||||||
teaser: Power spaCy components with custom neural networks
|
teaser: Power spaCy components with custom neural networks
|
||||||
menu:
|
menu:
|
||||||
- ['Type Signatures', 'type-sigs']
|
- ['Type Signatures', 'type-sigs']
|
||||||
- ['Defining Sublayers', 'sublayers']
|
- ['Swapping Architectures', 'swap-architectures']
|
||||||
- ['PyTorch & TensorFlow', 'frameworks']
|
- ['PyTorch & TensorFlow', 'frameworks']
|
||||||
|
- ['Thinc Models', 'thinc']
|
||||||
- ['Trainable Components', 'components']
|
- ['Trainable Components', 'components']
|
||||||
next: /usage/projects
|
next: /usage/projects
|
||||||
---
|
---
|
||||||
|
@ -22,8 +23,6 @@ its model architecture. The architecture is like a recipe for the network, and
|
||||||
you can't change the recipe once the dish has already been prepared. You have to
|
you can't change the recipe once the dish has already been prepared. You have to
|
||||||
make a new one.
|
make a new one.
|
||||||
|
|
||||||
![Diagram of a pipeline component with its model](../images/layers-architectures.svg)
|
|
||||||
|
|
||||||
## Type signatures {#type-sigs}
|
## Type signatures {#type-sigs}
|
||||||
|
|
||||||
<!-- TODO: update example, maybe simplify definition? -->
|
<!-- TODO: update example, maybe simplify definition? -->
|
||||||
|
@ -92,9 +91,13 @@ code.
|
||||||
|
|
||||||
</Infobox>
|
</Infobox>
|
||||||
|
|
||||||
## Defining sublayers {#sublayers}
|
## Swapping model architectures {#swap-architectures}
|
||||||
|
|
||||||
Model architecture functions often accept **sublayers as arguments**, so that
|
<!-- TODO: textcat example, using different architecture in the config -->
|
||||||
|
|
||||||
|
### Defining sublayers {#sublayers}
|
||||||
|
|
||||||
|
Model architecture functions often accept **sublayers as arguments**, so that
|
||||||
you can try **substituting a different layer** into the network. Depending on
|
you can try **substituting a different layer** into the network. Depending on
|
||||||
how the architecture function is structured, you might be able to define your
|
how the architecture function is structured, you might be able to define your
|
||||||
network structure entirely through the [config system](/usage/training#config),
|
network structure entirely through the [config system](/usage/training#config),
|
||||||
|
@ -114,62 +117,37 @@ approaches. And if you want to define your own solution, all you need to do is
|
||||||
register a ~~Model[List[Doc], List[Floats2d]]~~ architecture function, and
|
register a ~~Model[List[Doc], List[Floats2d]]~~ architecture function, and
|
||||||
you'll be able to try it out in any of the spaCy components.
|
you'll be able to try it out in any of the spaCy components.
|
||||||
|
|
||||||
<!-- TODO: example of switching sublayers -->
|
<!-- TODO: example of swapping sublayers -->
|
||||||
|
|
||||||
### Registering new architectures
|
|
||||||
|
|
||||||
- Recap concept, link to config docs.
|
|
||||||
|
|
||||||
## Wrapping PyTorch, TensorFlow and other frameworks {#frameworks}
|
## Wrapping PyTorch, TensorFlow and other frameworks {#frameworks}
|
||||||
|
|
||||||
<!-- TODO: this is copied over from the Thinc docs and we probably want to shorten it and make it more spaCy-specific -->
|
Thinc allows you to [wrap models](https://thinc.ai/docs/usage-frameworks)
|
||||||
|
written in other machine learning frameworks like PyTorch, TensorFlow and MXNet
|
||||||
|
using a unified [`Model`](https://thinc.ai/docs/api-model) API. As well as
|
||||||
|
**wrapping whole models**, Thinc lets you call into an external framework for
|
||||||
|
just **part of your model**: you can have a model where you use PyTorch just for
|
||||||
|
the transformer layers, using "native" Thinc layers to do fiddly input and
|
||||||
|
output transformations and add on task-specific "heads", as efficiency is less
|
||||||
|
of a consideration for those parts of the network.
|
||||||
|
|
||||||
Thinc allows you to wrap models written in other machine learning frameworks
|
<!-- TODO: custom tagger implemented in PyTorch, wrapped as Thinc model, link off to project (with notebook?) -->
|
||||||
like PyTorch, TensorFlow and MXNet using a unified
|
|
||||||
[`Model`](https://thinc.ai/docs/api-model) API. As well as **wrapping whole
|
|
||||||
models**, Thinc lets you call into an external framework for just **part of your
|
|
||||||
model**: you can have a model where you use PyTorch just for the transformer
|
|
||||||
layers, using "native" Thinc layers to do fiddly input and output
|
|
||||||
transformations and add on task-specific "heads", as efficiency is less of a
|
|
||||||
consideration for those parts of the network.
|
|
||||||
|
|
||||||
Thinc uses a special class, [`Shim`](https://thinc.ai/docs/api-model#shim), to
|
## Implementing models in Thinc {#thinc}
|
||||||
hold references to external objects. This allows each wrapper space to define a
|
|
||||||
custom type, with whatever attributes and methods are helpful, to assist in
|
|
||||||
managing the communication between Thinc and the external library. The
|
|
||||||
[`Model`](https://thinc.ai/docs/api-model#model) class holds `shim` instances in
|
|
||||||
a separate list, and communicates with the shims about updates, serialization,
|
|
||||||
changes of device, etc.
|
|
||||||
|
|
||||||
The wrapper will receive each batch of inputs, convert them into a suitable form
|
<!-- TODO: use same example as above, custom tagger, but implemented in Thinc, link off to Thinc docs where appropriate -->
|
||||||
for the underlying model instance, and pass them over to the shim, which will
|
|
||||||
**manage the actual communication** with the model. The output is then passed
|
|
||||||
back into the wrapper, and converted for use in the rest of the network. The
|
|
||||||
equivalent procedure happens during backpropagation. Array conversion is handled
|
|
||||||
via the [DLPack](https://github.com/dmlc/dlpack) standard wherever possible, so
|
|
||||||
that data can be passed between the frameworks **without copying the data back**
|
|
||||||
to the host device unnecessarily.
|
|
||||||
|
|
||||||
| Framework | Wrapper layer | Shim | DLPack |
|
|
||||||
| -------------- | ------------------------------------------------------------------------- | --------------------------------------------------------- | --------------- |
|
|
||||||
| **PyTorch** | [`PyTorchWrapper`](https://thinc.ai/docs/api-layers#pytorchwrapper) | [`PyTorchShim`](https://thinc.ai/docs/api-model#shims) | ✅ |
|
|
||||||
| **TensorFlow** | [`TensorFlowWrapper`](https://thinc.ai/docs/api-layers#tensorflowwrapper) | [`TensorFlowShim`](https://thinc.ai/docs/api-model#shims) | ❌ <sup>1</sup> |
|
|
||||||
| **MXNet** | [`MXNetWrapper`](https://thinc.ai/docs/api-layers#mxnetwrapper) | [`MXNetShim`](https://thinc.ai/docs/api-model#shims) | ✅ |
|
|
||||||
|
|
||||||
1. DLPack support in TensorFlow is now
|
|
||||||
[available](<(https://github.com/tensorflow/tensorflow/issues/24453)>) but
|
|
||||||
still experimental.
|
|
||||||
|
|
||||||
<!-- TODO:
|
|
||||||
- Explain concept
|
|
||||||
- Link off to notebook
|
|
||||||
-->
|
|
||||||
|
|
||||||
## Models for trainable components {#components}
|
## Models for trainable components {#components}
|
||||||
|
|
||||||
|
<!-- TODO:
|
||||||
|
|
||||||
- Interaction with `predict`, `get_loss` and `set_annotations`
|
- Interaction with `predict`, `get_loss` and `set_annotations`
|
||||||
- Initialization life-cycle with `begin_training`.
|
- Initialization life-cycle with `begin_training`.
|
||||||
- Link to relation extraction notebook.
|
|
||||||
|
Example: relation extraction component (implemented as project template)
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
![Diagram of a pipeline component with its model](../images/layers-architectures.svg)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def update(self, examples):
|
def update(self, examples):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user