
edsnlp.pipes.trainable.layers.crf

LinearChainCRF

Bases: Module

A linear-chain CRF in PyTorch.

Parameters

PARAMETER DESCRIPTION
forbidden_transitions

Shape: n_tags * n_tags. Impossible transitions (1 means impossible) from position n to position n+1.

start_forbidden_transitions

Shape: n_tags. Impossible transitions at the start of a sequence.

DEFAULT: None

end_forbidden_transitions

Shape: n_tags. Impossible transitions at the end of a sequence.

DEFAULT: None

learnable_transitions

Whether to learn transition scores in addition to the hard constraints.

DEFAULT: True

with_start_end_transitions

Whether to apply start and end transitions. If learnable_transitions is True, the start and end transition scores are learned as well.

DEFAULT: True
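To make the 1-means-impossible convention concrete, here is a minimal pure-Python sketch for a plain BIO scheme. It assumes a tag order O, B, I and a row = position n, column = position n+1 orientation for forbidden_transitions, and shows one common way such masks are combined with scores: forbidden entries become -inf, allowed ones keep a learned score (or 0 when transitions are not learned). The helper names mask_matrix and mask_vector are hypothetical and not part of edsnlp.

```python
import math

# Hypothetical tag order for a plain BIO scheme (not the edsnlp encoding).
O, B, I = range(3)
N_TAGS = 3

# forbidden[i][j] == 1: tag j may not follow tag i. Assumed orientation:
# row = tag at position n, column = tag at position n + 1.
forbidden = [[0] * N_TAGS for _ in range(N_TAGS)]
forbidden[O][I] = 1  # an entity cannot continue if none is open

start_forbidden = [0, 0, 1]  # a sequence cannot start inside an entity
end_forbidden = [0, 0, 0]    # in BIO, any tag may end a sequence


def mask_matrix(mask, learned=None):
    """Map a 0/1 transition mask to additive scores: -inf where forbidden,
    otherwise the learned score (or 0.0 if transitions are not learned)."""
    n = len(mask)
    return [
        [
            -math.inf if mask[i][j]
            else (learned[i][j] if learned is not None else 0.0)
            for j in range(n)
        ]
        for i in range(n)
    ]


def mask_vector(mask, learned=None):
    """Same as mask_matrix, for start/end masks of shape n_tags."""
    return [
        -math.inf if forbidden_tag
        else (learned[k] if learned is not None else 0.0)
        for k, forbidden_tag in enumerate(mask)
    ]


transitions = mask_matrix(forbidden)
start_scores = mask_vector(start_forbidden)
```

Because forbidden entries are -inf, any path using them scores -inf and can never win a max or contribute to a log-sum-exp, which is how hard constraints coexist with learned scores.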

decode

Decodes a sequence of tag scores using the Viterbi algorithm.

Parameters

PARAMETER DESCRIPTION
emissions

Shape: ... * n_tokens * ... * n_tags

mask

Shape: ... * n_tokens

RETURNS DESCRIPTION
LongTensor

Backtrack indices (= argmax), i.e. the best tag sequence.
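The Viterbi recursion behind decode can be sketched in pure Python for a single unpadded sequence (hence no mask). This is an illustration under assumptions, not the batched tensor implementation: the function name viterbi, the nested-list inputs and the additive start/end scores (with -inf marking forbidden entries) are all hypothetical.

```python
def viterbi(emissions, transitions, start=None, end=None):
    """Best tag path for one unpadded sequence.

    emissions[t][k]:   score of tag k at token t (n_tokens x n_tags, n_tokens >= 1)
    transitions[i][j]: score of tag i followed by tag j (-inf = forbidden)
    start[k], end[k]:  optional start/end scores (-inf = forbidden)
    """
    n_tags = len(transitions)
    start = start if start is not None else [0.0] * n_tags
    end = end if end is not None else [0.0] * n_tags

    # score[k]: best score of any path ending with tag k at the current token
    score = [start[k] + emissions[0][k] for k in range(n_tags)]
    backpointers = []
    for emission in emissions[1:]:
        # For each next tag j, keep the previous tag that maximizes the score
        best_prev = [
            max(range(n_tags), key=lambda i: score[i] + transitions[i][j])
            for j in range(n_tags)
        ]
        score = [
            score[best_prev[j]] + transitions[best_prev[j]][j] + emission[j]
            for j in range(n_tags)
        ]
        backpointers.append(best_prev)

    # Pick the best final tag, then follow the backpointers ("backtrack")
    path = [max(range(n_tags), key=lambda k: score[k] + end[k])]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    return path[::-1]
```

For example, with the start transition into tag 2 forbidden and the transition from tag 0 to tag 2 forbidden, the emissions [[0, 1, 2], [0, 0, 3]] decode to [1, 2] instead of the unconstrained [2, 2].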

marginal

Compute the marginal log-probabilities of the tags given the emissions, the transition probabilities and the constraints of the CRF.

This could also be computed with the propagate method, but this implementation is faster.

Parameters

PARAMETER DESCRIPTION
emissions

Shape: ... * n_tokens * n_tags

mask

Shape: ... * n_tokens

RETURNS DESCRIPTION
FloatTensor

Shape: ... * n_tokens * n_tags
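Marginals of a linear-chain CRF are usually obtained with the forward-backward algorithm in log space. The sketch below shows that technique for one unpadded sequence; it is not necessarily how edsnlp's marginal is implemented, and the names logsumexp and marginals, as well as the list-based inputs, are assumptions.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginals(emissions, transitions, start=None, end=None):
    """log P(y_t = k | x) for one unpadded sequence (n_tokens x n_tags)."""
    n_tokens, n_tags = len(emissions), len(transitions)
    start = start if start is not None else [0.0] * n_tags
    end = end if end is not None else [0.0] * n_tags

    # Forward pass: alpha[t][k] sums (in log space) all prefixes ending in k at t
    alpha = [[start[k] + emissions[0][k] for k in range(n_tags)]]
    for t in range(1, n_tokens):
        alpha.append([
            logsumexp([alpha[-1][i] + transitions[i][j] for i in range(n_tags)])
            + emissions[t][j]
            for j in range(n_tags)
        ])

    # Backward pass: beta[t][k] sums all suffixes that follow tag k at t
    beta = [list(end)]
    for t in range(n_tokens - 2, -1, -1):
        beta.insert(0, [
            logsumexp([
                transitions[i][j] + emissions[t + 1][j] + beta[0][j]
                for j in range(n_tags)
            ])
            for i in range(n_tags)
        ])

    # log Z: log-sum over every complete path
    log_z = logsumexp([alpha[-1][k] + end[k] for k in range(n_tags)])
    return [
        [alpha[t][k] + beta[t][k] - log_z for k in range(n_tags)]
        for t in range(n_tokens)
    ]
```

Each row exponentiates to a distribution that sums to 1, and a tag ruled out by the constraints gets a log-marginal of -inf.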

forward

Compute the reduced posterior log-probabilities of the target tags given the emissions, the transition probabilities and the constraints of the CRF, i.e. the loss.

This could also be computed with the propagate method, but this implementation is faster.

Parameters

PARAMETER DESCRIPTION
emissions

Shape: n_samples * n_tokens * ... * n_tags

mask

Shape: n_samples * n_tokens * ...

target

Shape: n_samples * n_tokens * ... * n_tags. The target tags, one-hot encoded. A one-hot format is used instead of a long (index) format so that multiple tags can be allowed at a given position during training.

RETURNS DESCRIPTION
FloatTensor

The loss. Shape: ...
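The one-hot (or multi-hot) target suggests the following formulation, sketched here in pure Python for one sequence under stated assumptions: emissions of tags the target does not accept are set to -inf, so the forward algorithm sums over exactly the target-compatible paths, and the result is normalized by the partition function. The function names are hypothetical, and this is not claimed to be edsnlp's code.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_partition(emissions, transitions, start, end):
    """Forward algorithm: log-sum of the scores of all tag paths."""
    n_tags = len(transitions)
    alpha = [start[k] + emissions[0][k] for k in range(n_tags)]
    for emission in emissions[1:]:
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(n_tags)])
            + emission[j]
            for j in range(n_tags)
        ]
    return logsumexp([alpha[k] + end[k] for k in range(n_tags)])


def target_log_prob(emissions, transitions, target, start=None, end=None):
    """log P(tag path compatible with target | x) for one unpadded sequence.

    target[t][k] is 1 if tag k is acceptable at token t: one-hot for a single
    gold path, multi-hot when several tags are allowed at a position.
    """
    n_tags = len(transitions)
    start = start if start is not None else [0.0] * n_tags
    end = end if end is not None else [0.0] * n_tags
    # Rule out every tag the target does not accept; the forward algorithm
    # then sums over exactly the target-compatible paths.
    restricted = [
        [score if ok else -math.inf for score, ok in zip(row, target_row)]
        for row, target_row in zip(emissions, target)
    ]
    return (log_partition(restricted, transitions, start, end)
            - log_partition(emissions, transitions, start, end))
```

Minimizing the negation of this value trains the CRF; the sign convention and the reduction over samples used by forward may differ from this sketch.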

MultiLabelBIOULDecoder

Bases: LinearChainCRF

Create a linear-chain CRF with hard constraints that enforce the BIOUL tagging scheme.

Parameters

PARAMETER DESCRIPTION
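num_labels

Number of labels; each token receives one BIOUL tag per label.

with_start_end_transitions

Same as in LinearChainCRF: whether to apply start and end transitions.

DEFAULT: True

learnable_transitions

Same as in LinearChainCRF: whether to learn transition scores in addition to the hard constraints.

DEFAULT: True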
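One plausible way to derive the BIOUL hard constraints is shown below as a pure-Python sketch for a single label. The tag order O, I, B, L, U is an assumption (the encoding actually used by MultiLabelBIOULDecoder is not given here), as is the idea that each label gets its own independent BIOUL chain. The rule: after B or I an entity is open, so only I or L may follow; otherwise only O, B or U may follow. A sequence cannot start with I or L, nor end with B or I.

```python
O, I, B, L, U = range(5)  # hypothetical tag order
N_TAGS = 5
OPEN_AFTER = {B, I}       # an entity is still open after these tags
NEEDS_OPEN = {I, L}       # these tags continue or close an open entity


def bioul_constraints():
    """Return (forbidden, start_forbidden, end_forbidden) with 1 = impossible."""
    # A transition is allowed iff "entity open" matches "tag needs an open entity"
    forbidden = [
        [int((prev in OPEN_AFTER) != (nxt in NEEDS_OPEN)) for nxt in range(N_TAGS)]
        for prev in range(N_TAGS)
    ]
    start_forbidden = [int(tag in NEEDS_OPEN) for tag in range(N_TAGS)]
    end_forbidden = [int(tag in OPEN_AFTER) for tag in range(N_TAGS)]
    return forbidden, start_forbidden, end_forbidden
```

These three masks have the shapes documented for LinearChainCRF (n_tags * n_tags, n_tags, n_tags), which is why a BIOUL decoder can be expressed as a constrained LinearChainCRF.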

tags_to_spans staticmethod

Convert a sequence of multi-label BIOUL tags into a sequence of spans.

Parameters

PARAMETER DESCRIPTION
tags

Shape: n_samples * n_tokens * n_labels

RETURNS DESCRIPTION
LongTensor

Shape: n_spans * 4 (doc_idx, begin, end, label_idx)
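A pure-Python sketch of the same conversion, assuming the hypothetical tag order O, I, B, L, U, well-formed BIOUL input (such as Viterbi output under the BIOUL constraints) and an exclusive end index; the edsnlp version operates on tensors and may differ in these details.

```python
O, I, B, L, U = range(5)  # hypothetical tag order


def tags_to_spans(tags):
    """tags[doc][token][label] -> sorted list of (doc_idx, begin, end, label_idx).

    A span opens on B or U and closes (end exclusive) on L or U.
    """
    spans = []
    for doc_idx, doc in enumerate(tags):
        n_labels = len(doc[0]) if doc else 0
        for label_idx in range(n_labels):
            begin = None
            for token_idx, token_tags in enumerate(doc):
                tag = token_tags[label_idx]
                if tag in (B, U):
                    begin = token_idx
                if tag in (L, U) and begin is not None:
                    spans.append((doc_idx, begin, token_idx + 1, label_idx))
                    begin = None
    return sorted(spans)
```

For one document with label 0 tagged B L O U and label 1 tagged O U O O, the sketch returns (0, 0, 2, 0), (0, 1, 2, 1) and (0, 3, 4, 0).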

try_torch_compile

Call torch.compile(fn) when possible, or fall back to fn if compiler errors occur.

Only exceptions raised by the torch compiler stack (Dynamo/Inductor/Functorch) trigger the fallback; all other exceptions are re-raised, since they may indicate a real bug.
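The fallback pattern can be sketched without depending on torch by injecting the compiler and the exception types. Because torch.compile compiles lazily on the first call, the wrapper catches errors at call time as well as at wrap time. The names try_compile, compiler and fallback_errors are hypothetical; with PyTorch one might pass torch.compile and a tuple such as (torch._dynamo.exc.TorchDynamoException,), but the exact exception classes edsnlp catches may differ.

```python
import functools


def try_compile(fn, compiler, fallback_errors):
    """Wrap fn so that compiler failures fall back to the eager fn.

    fallback_errors: tuple of exception types treated as compiler errors.
    Any other exception is propagated unchanged.
    """
    try:
        compiled = compiler(fn)
    except fallback_errors:
        return fn  # compilation failed eagerly: use the plain function
    use_compiled = True

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal use_compiled
        if use_compiled:
            try:
                return compiled(*args, **kwargs)
            except fallback_errors:
                use_compiled = False  # stop retrying the compiler
        return fn(*args, **kwargs)

    return wrapper
```

Once a compiler-stack error has been seen, the wrapper stops retrying compilation and always calls the eager function; any other exception, such as a ValueError raised by fn itself, propagates unchanged.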