diff --git a/spacy/ml/tb_framework.pyx b/spacy/ml/tb_framework.pyx index e98d59a8a..62d876e33 100644 --- a/spacy/ml/tb_framework.pyx +++ b/spacy/ml/tb_framework.pyx @@ -1,5 +1,5 @@ # cython: infer_types=True, cdivision=True, boundscheck=False -from typing import List, Tuple, Any, Optional, cast +from typing import List, Tuple, Any, Optional, TypeVar, cast from libc.string cimport memset, memcpy from libc.stdlib cimport calloc, free, realloc from libcpp.vector cimport vector @@ -10,12 +10,14 @@ from thinc.api import uniform_init, glorot_uniform_init, zero_init from thinc.api import NumpyOps from thinc.backends.linalg cimport Vec, VecVec from thinc.backends.cblas cimport CBlas -from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d +from thinc.types import Floats1d, Floats2d, Floats3d, Floats4d +from thinc.types import Ints1d, Ints2d from ..errors import Errors from ..pipeline._parser_internals import _beam_utils from ..pipeline._parser_internals.batch import GreedyBatch -from ..pipeline._parser_internals.transition_system cimport c_transition_batch, TransitionSystem +from ..pipeline._parser_internals.transition_system cimport c_transition_batch, c_apply_actions +from ..pipeline._parser_internals.transition_system cimport TransitionSystem from ..pipeline._parser_internals.stateclass cimport StateC, StateClass from ..tokens.doc import Doc from ..util import registry @@ -43,18 +45,28 @@ def TransitionModel( tok2vec_projected = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) # type: ignore tok2vec_projected.set_dim("nO", hidden_width) + # FIXME: we use `upper` as a container for the upper layer's + # weights and biases. Thinc optimizers cannot handle resizing + # of parameters. So, when the parser model is resized, we + # construct a new `upper` layer, which has a different key in + # the optimizer. Once the optimizer supports parameter resizing, + # we can replace the `upper` layer by `upper_W` and `upper_b` + # parameters in this model. + upper = Linear(nO=None, nI=hidden_width, init_W=zero_init) + return Model( name="parser_model", forward=forward, init=init, - layers=[tok2vec_projected], - refs={"tok2vec": tok2vec_projected}, + layers=[tok2vec_projected, upper], + refs={ + "tok2vec": tok2vec_projected, + "upper": upper, + }, params={ "lower_W": None, # Floats2d W for the hidden layer "lower_b": None, # Floats1d bias for the hidden layer "lower_pad": None, # Floats1d padding for the hidden layer - "upper_W": None, # Floats2d W for the output layer - "upper_b": None, # Floats1d bias for the output layer }, dims={ "nO": None, # Output size @@ -74,29 +86,30 @@ def TransitionModel( def resize_output(model: Model, new_nO: int) -> Model: old_nO = model.maybe_get_dim("nO") + upper = model.get_ref("upper") if old_nO is None: model.set_dim("nO", new_nO) + upper.set_dim("nO", new_nO) + upper.initialize() return model elif new_nO <= old_nO: return model - elif model.has_param("upper_W"): + elif upper.has_param("W"): nH = model.get_dim("nH") - new_W = model.ops.alloc2f(new_nO, nH) - new_b = model.ops.alloc1f(new_nO) - old_W = model.get_param("upper_W") - old_b = model.get_param("upper_b") + new_upper = Linear(nO=new_nO, nI=nH, init_W=zero_init) + new_upper.initialize() + new_W = new_upper.get_param("W") + new_b = new_upper.get_param("b") + old_W = upper.get_param("W") + old_b = upper.get_param("b") new_W[:old_nO] = old_W # type: ignore new_b[:old_nO] = old_b # type: ignore for i in range(old_nO, new_nO): model.attrs["unseen_classes"].add(i) - model.set_param("upper_W", new_W) - model.set_param("upper_b", new_b) + model.layers[-1] = new_upper + model.set_ref("upper", new_upper) # TODO: Avoid this private intrusion model._dims["nO"] = new_nO - if model.has_grad("upper_W"): - model.set_grad("upper_W", model.get_param("upper_W") * 0) - if model.has_grad("upper_b"): - model.set_grad("upper_b", model.get_param("upper_b") * 0) return model @@ -113,9 +126,7 @@ def init( inferred_nO = _infer_nO(Y) if inferred_nO is not None: current_nO = model.maybe_get_dim("nO") - if current_nO is None: - model.set_dim("nO", inferred_nO) - elif current_nO != inferred_nO: + if current_nO is None or current_nO != inferred_nO: model.attrs["resize_output"](model, inferred_nO) nO = model.get_dim("nO") nP = model.get_dim("nP") @@ -127,9 +138,6 @@ def init( Wl = ops.alloc2f(nH * nP, nF * nI) bl = ops.alloc1f(nH * nP) padl = ops.alloc1f(nI) - Wu = ops.alloc2f(nO, nH) - bu = ops.alloc1f(nO) - Wu = zero_init(ops, Wu.shape) # Wl = zero_init(ops, Wl.shape) Wl = glorot_uniform_init(ops, Wl.shape) padl = uniform_init(ops, padl.shape) # type: ignore @@ -137,41 +145,50 @@ def init( model.set_param("lower_W", Wl) model.set_param("lower_b", bl) model.set_param("lower_pad", padl) - model.set_param("upper_W", Wu) - model.set_param("upper_b", bu) # model = _lsuv_init(model) return model -def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool): +InWithoutActions = Tuple[List[Doc], TransitionSystem] +InWithActions = Tuple[List[Doc], TransitionSystem, List[Ints1d]] +InT = TypeVar("InT", InWithoutActions, InWithActions) + +def forward(model, docs_moves: InT, is_train: bool): + if len(docs_moves) == 2: + docs, moves = docs_moves + actions = None + else: + docs, moves, actions = docs_moves + beam_width = model.attrs["beam_width"] lower_pad = model.get_param("lower_pad") tok2vec = model.get_ref("tok2vec") - docs, moves = docs_moves states = moves.init_batch(docs) tokvecs, backprop_tok2vec = tok2vec(docs, is_train) tokvecs = model.ops.xp.vstack((tokvecs, lower_pad)) feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train) seen_mask = _get_seen_mask(model) + # Fixme: support actions in forward_cpu if beam_width == 1 and not is_train and isinstance(model.ops, NumpyOps): - return _forward_greedy_cpu(model, moves, states, feats, seen_mask) + return _forward_greedy_cpu(model, moves, states, feats, seen_mask, actions=actions) else: - return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train) + return _forward_fallback(model, moves, states, tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train, actions=actions) + def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[StateClass], np.ndarray feats, - np.ndarray[np.npy_bool, ndim=1] seen_mask): - cdef vector[StateC *] c_states + np.ndarray[np.npy_bool, ndim=1] seen_mask, actions: Optional[List[Ints1d]]=None): + cdef vector[StateC*] c_states cdef StateClass state for state in states: if not state.is_final(): c_states.push_back(state.c) - weights = get_c_weights(model, feats.data, seen_mask) + weights = get_c_weights(model, feats.data, seen_mask) # Precomputed features have rows for each token, plus one for padding. cdef int n_tokens = feats.shape[0] - 1 sizes = get_c_sizes(model, c_states.size(), n_tokens) cdef CBlas cblas = model.ops.cblas() - scores = _parseC(cblas, moves, &c_states[0], weights, sizes) + scores = _parseC(cblas, moves, &c_states[0], weights, sizes, actions=actions) def backprop(dY): raise ValueError(Errors.E1038) @@ -179,20 +196,25 @@ def _forward_greedy_cpu(model: Model, TransitionSystem moves, states: List[State return (states, scores), backprop cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states, - WeightsC weights, SizesC sizes): + WeightsC weights, SizesC sizes, actions: Optional[List[Ints1d]]=None): cdef int i, j cdef vector[StateC *] unfinished cdef ActivationsC activations = alloc_activations(sizes) cdef np.ndarray step_scores + cdef np.ndarray step_actions scores = [] while sizes.states >= 1: step_scores = numpy.empty((sizes.states, sizes.classes), dtype="f") + step_actions = actions[0] if actions is not None else None with nogil: - predict_states(cblas, &activations, step_scores.data, states, &weights, sizes) - # Validate actions, argmax, take action. - c_transition_batch(moves, states, step_scores.data, sizes.classes, - sizes.states) + predict_states(cblas, &activations, step_scores.data, states, &weights, sizes) + if actions is None: + # Validate actions, argmax, take action. + c_transition_batch(moves, states, step_scores.data, sizes.classes, + sizes.states) + else: + c_apply_actions(moves, states, step_actions.data, sizes.states) for i in range(sizes.states): if not states[i].is_final(): unfinished.push_back(states[i]) @@ -201,15 +223,17 @@ cdef list _parseC(CBlas cblas, TransitionSystem moves, StateC** states, sizes.states = unfinished.size() scores.append(step_scores) unfinished.clear() + actions = actions[1:] if actions is not None else None free_activations(&activations) return scores -def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool): + +def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateClass], tokvecs, backprop_tok2vec, feats, backprop_feats, seen_mask, is_train: bool, + actions: Optional[List[Ints1d]]=None): nF = model.get_dim("nF") + upper = model.get_ref("upper") lower_b = model.get_param("lower_b") - upper_W = model.get_param("upper_W") - upper_b = model.get_param("upper_b") nH = model.get_dim("nH") nP = model.get_dim("nP") @@ -240,14 +264,17 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC preacts = ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP) assert preacts.shape[0] == len(batch.get_unfinished_states()), preacts.shape statevecs, which = ops.maxout(preacts) - # Multiply the state-vector by the scores weights and add the bias, - # to get the logits. - scores = ops.gemm(statevecs, upper_W, trans2=True) - scores += upper_b + # We don't use upper's backprop, since we want to backprop for + # all states at once, rather than a single state. + scores = upper.predict(statevecs) scores[:, seen_mask] = ops.xp.nanmin(scores) # Transition the states, filtering out any that are finished. cpu_scores = ops.to_numpy(scores) - batch.advance(cpu_scores) + if actions is None: + batch.advance(cpu_scores) + else: + batch.advance_with_actions(actions[0]) + actions = actions[1:] all_scores.append(scores) if is_train: # Remember intermediate results for the backprop. @@ -270,10 +297,11 @@ def _forward_fallback(model: Model, moves: TransitionSystem, states: List[StateC d_scores *= seen_mask == False # Calculate the gradients for the parameters of the upper layer. # The weight gemm is (nS, nO) @ (nS, nH).T - model.inc_grad("upper_b", d_scores.sum(axis=0)) - model.inc_grad("upper_W", ops.gemm(d_scores, statevecs, trans1=True)) + upper.inc_grad("b", d_scores.sum(axis=0)) + upper.inc_grad("W", ops.gemm(d_scores, statevecs, trans1=True)) # Now calculate d_statevecs, by backproping through the upper linear layer. # This gemm is (nS, nO) @ (nO, nH) + upper_W = upper.get_param("W") d_statevecs = ops.gemm(d_scores, upper_W) # Backprop through the maxout activation d_preacts = ops.backprop_maxout(d_statevecs, which, nP) @@ -295,11 +323,10 @@ def _forward_reference( """Slow reference implementation, without the precomputation""" nF = model.get_dim("nF") tok2vec = model.get_ref("tok2vec") + upper = model.get_ref("upper") lower_pad = model.get_param("lower_pad") lower_W = model.get_param("lower_W") lower_b = model.get_param("lower_b") - upper_W = model.get_param("upper_W") - upper_b = model.get_param("upper_b") nH = model.get_dim("nH") nP = model.get_dim("nP") nO = model.get_dim("nO") @@ -330,10 +357,9 @@ def _forward_reference( preacts2f += lower_b preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP) statevecs, which = ops.maxout(preacts) - # Multiply the state-vector by the scores weights and add the bias, - # to get the logits. - scores = model.ops.gemm(statevecs, upper_W, trans2=True) - scores += upper_b + # We don't use upper's backprop, since we want to backprop for + # all states at once, rather than a single state. + scores = upper.predict(statevecs) scores[:, seen_mask] = model.ops.xp.nanmin(scores) # Transition the states, filtering out any that are finished. next_states = moves.transition_states(next_states, scores) @@ -366,10 +392,11 @@ def _forward_reference( assert d_scores.shape == (nS, nO), d_scores.shape # Calculate the gradients for the parameters of the upper layer. # The weight gemm is (nS, nO) @ (nS, nH).T - model.inc_grad("upper_b", d_scores.sum(axis=0)) - model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True)) + upper.inc_grad("b", d_scores.sum(axis=0)) + upper.inc_grad("W", model.ops.gemm(d_scores, statevecs, trans1=True)) # Now calculate d_statevecs, by backproping through the upper linear layer. # This gemm is (nS, nO) @ (nO, nH) + upper_W = upper.get_param("W") d_statevecs = model.ops.gemm(d_scores, upper_W) # Backprop through the maxout activation d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP) @@ -514,9 +541,10 @@ def _lsuv_init(model: Model): cdef WeightsC get_c_weights(model, const float* feats, np.ndarray[np.npy_bool, ndim=1] seen_mask) except *: + upper = model.get_ref("upper") cdef np.ndarray lower_b = model.get_param("lower_b") - cdef np.ndarray upper_W = model.get_param("upper_W") - cdef np.ndarray upper_b = model.get_param("upper_b") + cdef np.ndarray upper_W = upper.get_param("W") + cdef np.ndarray upper_b = upper.get_param("b") cdef WeightsC output output.feat_weights = feats diff --git a/spacy/pipeline/_parser_internals/batch.pyx b/spacy/pipeline/_parser_internals/batch.pyx index 7928fb0b9..93b8e08c1 100644 --- a/spacy/pipeline/_parser_internals/batch.pyx +++ b/spacy/pipeline/_parser_internals/batch.pyx @@ -32,6 +32,9 @@ class GreedyBatch(Batch): def advance(self, scores): self._next_states = self._moves.transition_states(self._next_states, scores) + def advance_with_actions(self, actions): + self._next_states = self._moves.apply_transitions(self._next_states, actions) + def get_states(self): return self._states diff --git a/spacy/pipeline/_parser_internals/transition_system.pxd b/spacy/pipeline/_parser_internals/transition_system.pxd index d2bc0f781..c8ebd8b27 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pxd +++ b/spacy/pipeline/_parser_internals/transition_system.pxd @@ -54,5 +54,9 @@ cdef class TransitionSystem: cdef int set_costs(self, int* is_valid, weight_t* costs, const StateC* state, gold) except -1 + +cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions, + int batch_size) nogil + cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores, int nr_class, int batch_size) nogil diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index 201128283..dd18606c1 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -155,6 +155,17 @@ cdef class TransitionSystem: action.do(state.c, action.label) state.c.history.push_back(action.clas) + def apply_actions(self, states, const int[::1] actions): + assert len(states) == actions.shape[0] + cdef StateClass state + cdef vector[StateC*] c_states + c_states.resize(len(states)) + cdef int i + for (i, state) in enumerate(states): + c_states[i] = state.c + c_apply_actions(self, &c_states[0], &actions[0], actions.shape[0]) + return [state for state in states if not state.c.is_final()] + def transition_states(self, states, float[:, ::1] scores): assert len(states) == scores.shape[0] cdef StateClass state @@ -279,6 +290,18 @@ cdef class TransitionSystem: return self +cdef void c_apply_actions(TransitionSystem moves, StateC** states, const int* actions, + int batch_size) nogil: + cdef int i + cdef Transition action + cdef StateC* state + for i in range(batch_size): + state = states[i] + action = moves.c[actions[i]] + action.do(state, action.label) + state.history.push_back(action.clas) + + cdef void c_transition_batch(TransitionSystem moves, StateC** states, const float* scores, int nr_class, int batch_size) nogil: is_valid = calloc(moves.n_moves, sizeof(int)) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index ce1d7e717..13592e0be 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -1,5 +1,6 @@ # cython: infer_types=True, cdivision=True, boundscheck=False, binding=True from __future__ import print_function +from typing import List from cymem.cymem cimport Pool cimport numpy as np from itertools import islice @@ -12,11 +13,12 @@ import contextlib import srsly from thinc.api import set_dropout_rate, CupyOps, get_array_module from thinc.extra.search cimport Beam +from thinc.types import Ints1d import numpy.random import numpy import warnings -from ._parser_internals.stateclass cimport StateClass +from ._parser_internals.stateclass cimport StateC, StateClass from ..tokens.doc cimport Doc from .trainable_pipe import TrainablePipe from ._parser_internals cimport _beam_utils @@ -359,7 +361,42 @@ class Parser(TrainablePipe): def rehearse(self, examples, sgd=None, losses=None, **cfg): """Perform a "rehearsal" update, to prevent catastrophic forgetting.""" - raise NotImplementedError + if losses is None: + losses = {} + for multitask in self._multitasks: + if hasattr(multitask, 'rehearse'): + multitask.rehearse(examples, losses=losses, sgd=sgd) + if self._rehearsal_model is None: + return None + losses.setdefault(self.name, 0.0) + validate_examples(examples, "Parser.rehearse") + docs = [eg.predicted for eg in examples] + # This is pretty dirty, but the NER can resize itself in init_batch, + # if labels are missing. We therefore have to check whether we need to + # expand our model output. + self._resize() + # Prepare the stepwise model, and get the callback for finishing the batch + set_dropout_rate(self._rehearsal_model, 0.0) + set_dropout_rate(self.model, 0.0) + (student_states, student_scores), backprop_scores = self.model.begin_update((docs, self.moves)) + actions = states2actions(student_states) + _, teacher_scores = self._rehearsal_model.predict((docs, self.moves, actions)) + + teacher_scores = self.model.ops.xp.vstack(teacher_scores) + student_scores = self.model.ops.xp.vstack(student_scores) + assert teacher_scores.shape == student_scores.shape + + d_scores = (student_scores - teacher_scores) / teacher_scores.shape[0] + # If all weights for an output are 0 in the original model, don't + # supervise that output. This allows us to add classes. + loss = (d_scores**2).sum() / d_scores.size + backprop_scores((student_states, d_scores)) + + if sgd is not None: + self.finish_update(sgd) + losses[self.name] += loss + + return losses def update_beam(self, examples, *, beam_width, drop=0., sgd=None, losses=None, beam_density=0.0): @@ -474,3 +511,26 @@ def _change_attrs(model, **kwargs): model.attrs.pop(key) else: model.attrs[key] = value + + +def states2actions(states: List[StateClass]) -> List[Ints1d]: + cdef int step + cdef StateClass state + cdef StateC* c_state + actions = [] + while True: + step = len(actions) + + step_actions = [] + for state in states: + c_state = state.c + if step < c_state.history.size(): + step_actions.append(c_state.history[step]) + + # We are done if we have exhausted all histories. + if len(step_actions) == 0: + break + + actions.append(numpy.array(step_actions, dtype="i")) + + return actions diff --git a/spacy/tests/parser/test_parse.py b/spacy/tests/parser/test_parse.py index 9fc8115f5..0719544d3 100644 --- a/spacy/tests/parser/test_parse.py +++ b/spacy/tests/parser/test_parse.py @@ -1,4 +1,5 @@ import pytest +import numpy from numpy.testing import assert_equal from thinc.api import Adam @@ -175,6 +176,57 @@ def test_parser_parse_one_word_sentence(en_vocab, en_parser, words): assert doc[0].dep != 0 +def test_parser_apply_actions(en_vocab, en_parser): + words = ["I", "ate", "pizza"] + words2 = ["Eat", "more", "pizza", "!"] + doc1 = Doc(en_vocab, words=words) + doc2 = Doc(en_vocab, words=words2) + docs = [doc1, doc2] + + moves = en_parser.moves + moves.add_action(0, "") + moves.add_action(1, "") + moves.add_action(2, "nsubj") + moves.add_action(3, "obj") + moves.add_action(2, "amod") + + actions = [ + numpy.array([0, 0], dtype="i"), + numpy.array([2, 0], dtype="i"), + numpy.array([0, 4], dtype="i"), + numpy.array([3, 3], dtype="i"), + numpy.array([1, 1], dtype="i"), + numpy.array([1, 1], dtype="i"), + numpy.array([0], dtype="i"), + numpy.array([1], dtype="i"), + ] + + states = moves.init_batch(docs) + active_states = states + + for step_actions in actions: + active_states = moves.apply_actions(active_states, step_actions) + + assert len(active_states) == 0 + + for (state, doc) in zip(states, docs): + moves.set_annotations(state, doc) + + assert docs[0][0].head.i == 1 + assert docs[0][0].dep_ == "nsubj" + assert docs[0][1].head.i == 1 + assert docs[0][1].dep_ == "ROOT" + assert docs[0][2].head.i == 1 + assert docs[0][2].dep_ == "obj" + + assert docs[1][0].head.i == 0 + assert docs[1][0].dep_ == "ROOT" + assert docs[1][1].head.i == 2 + assert docs[1][1].dep_ == "amod" + assert docs[1][2].head.i == 0 + assert docs[1][2].dep_ == "obj" + + @pytest.mark.skip( reason="The step_through API was removed (but should be brought back)" ) diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 238f308e7..95208141d 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -264,9 +264,11 @@ def test_serialize_custom_nlp(): model = nlp2.get_pipe("parser").model assert model.get_ref("tok2vec") is not None assert model.has_param("lower_W") - assert model.has_param("upper_W") assert model.has_param("lower_b") - assert model.has_param("upper_b") + upper = model.get_ref("upper") + assert upper is not None + assert upper.has_param("W") + assert upper.has_param("b") @pytest.mark.parametrize("parser_config_string", [parser_config_string_upper]) @@ -284,9 +286,11 @@ def test_serialize_parser(parser_config_string): model = nlp2.get_pipe("parser").model assert model.get_ref("tok2vec") is not None assert model.has_param("lower_W") - assert model.has_param("upper_W") assert model.has_param("lower_b") - assert model.has_param("upper_b") + upper = model.get_ref("upper") + assert upper is not None + assert upper.has_param("b") + assert upper.has_param("W") def test_config_nlp_roundtrip(): diff --git a/spacy/tests/training/test_rehearse.py b/spacy/tests/training/test_rehearse.py index 45829a01f..5ac7fc217 100644 --- a/spacy/tests/training/test_rehearse.py +++ b/spacy/tests/training/test_rehearse.py @@ -203,8 +203,7 @@ def _optimize(nlp, component: str, data: List, rehearse: bool): return nlp -# Fixme: reenable ner and parser when rehearsal is implemented. -@pytest.mark.parametrize("component", ["tagger", "textcat_multilabel"]) +@pytest.mark.parametrize("component", ["ner", "tagger", "parser", "textcat_multilabel"]) def test_rehearse(component): nlp = spacy.blank("en") nlp.add_pipe(component)