
edsnlp.pipes.trainable.doc_classifier.factory

create_component = registry.factory.register('eds.doc_classifier', assigns=['doc._.predicted_class'], deprecated=[])(TrainableDocClassifier) module-attribute

The eds.doc_classifier component is a trainable document-level classifier. In this context, document classification consists of predicting one or more categorical labels for the whole document (e.g. a diagnosis code, a discharge status, or any other metadata derived from the entire document).

Unlike span classification, where predictions are attached to spans, the document classifier attaches predictions to the Doc object itself.

Architecture

The model performs multi-head document classification by:

  1. Calling an embedding component (e.g. eds.doc_pooler wrapping eds.transformer) to compute a single pooled embedding for the whole document.
  2. Feeding the pooled embedding into one or more classification heads. Each head is defined by a linear layer, optionally preceded by a head-specific hidden layer with activation, dropout, and layer normalization.
  3. Computing independent logits for each head.
  4. Training with a per-head loss (CrossEntropy or Focal), optionally using class weights to handle imbalance.
  5. Aggregating head losses into a single training loss (simple average).
  6. During inference, assigning the predicted label for each head to the Doc extension attribute named after that head, doc._.<head_name> (e.g. doc._.icd10).
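The steps above can be sketched without any dependencies. This is a minimal, hypothetical illustration rather than the component's own code: `Head`, `linear`, `layer_norm`, `random_layer` and `predict` are made-up names, the weights are random instead of trained, and applying layer normalization after activation (and dropout) is an assumption. It shows that every head reads the same pooled vector, produces its own logits, and is decoded independently with an argmax over its own label set.

```python
import math
import random
from typing import Dict, List, Optional, Tuple

Matrix = List[List[float]]
Layer = Tuple[Matrix, List[float]]  # (weight of shape [out, in], bias of shape [out])


def linear(x: List[float], layer: Layer) -> List[float]:
    """Affine map y = W x + b."""
    weight, bias = layer
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: List[float]) -> List[float]:
    return [max(0.0, v) for v in x]


def layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    # Normalize to zero mean and unit variance (no learned scale or shift here).
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def random_layer(out_size: int, in_size: int, rng: random.Random) -> Layer:
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_size)] for _ in range(out_size)]
    return weight, [0.0] * out_size


class Head:
    """One hypothetical head: optional hidden layer, then a linear classifier."""

    def __init__(self, in_size: int, num_classes: int,
                 hidden_size: Optional[int] = None, use_layer_norm: bool = False,
                 seed: int = 0):
        rng = random.Random(seed)
        self.hidden = random_layer(hidden_size, in_size, rng) if hidden_size else None
        self.use_layer_norm = use_layer_norm
        self.classifier = random_layer(num_classes, hidden_size or in_size, rng)

    def __call__(self, pooled: List[float]) -> List[float]:
        x = pooled
        if self.hidden is not None:
            x = relu(linear(x, self.hidden))
            # Dropout would be applied here during training; it is the identity at inference.
            if self.use_layer_norm:
                x = layer_norm(x)
        return linear(x, self.classifier)  # logits for this head only


def predict(pooled: List[float], heads: Dict[str, Head],
            id2label: Dict[str, Dict[int, str]]) -> Dict[str, str]:
    # Every head reads the same pooled embedding and is decoded independently.
    predictions = {}
    for name, head in heads.items():
        logits = head(pooled)
        best = max(range(len(logits)), key=logits.__getitem__)
        predictions[name] = id2label[name][best]
    return predictions


heads = {
    "icd10": Head(in_size=4, num_classes=3, hidden_size=8, use_layer_norm=True),
    "mortality": Head(in_size=4, num_classes=2),
}
id2label = {"icd10": {0: "I10", 1: "J18", 2: "K35"}, "mortality": {0: "alive", 1: "dead"}}
print(predict([0.5, -1.0, 2.0, 0.1], heads, id2label))  # untrained, so labels are arbitrary
```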

Each classification head is independent, so different tasks (e.g. predicting ICD-10 category vs. mortality flag) can be trained jointly on the same pooled embeddings.
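Because the heads share only the pooled embedding, the training objective is simply the per-head losses averaged together. The sketch below computes it for a single example. It assumes class weights are given as a list indexed by class ID and uses `gamma=2.0` as the focal-loss focusing parameter, an assumed value that is not documented here; `multi_head_loss` and the other helpers are hypothetical names, and reduction over a batch is omitted.

```python
import math
from typing import Dict, List, Optional


def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def cross_entropy(logits: List[float], target: int,
                  weights: Optional[List[float]] = None) -> float:
    # -w[y] * log p[y] for one example
    w = 1.0 if weights is None else weights[target]
    return -w * log_softmax(logits)[target]


def focal_loss(logits: List[float], target: int, gamma: float = 2.0,
               weights: Optional[List[float]] = None) -> float:
    # -w[y] * (1 - p[y]) ** gamma * log p[y]; gamma=2.0 is an assumed default
    log_p = log_softmax(logits)[target]
    w = 1.0 if weights is None else weights[target]
    return -w * (1.0 - math.exp(log_p)) ** gamma * log_p


def multi_head_loss(logits_per_head: Dict[str, List[float]],
                    targets: Dict[str, int],
                    loss_per_head: Dict[str, str],
                    weights_per_head: Optional[Dict[str, List[float]]] = None) -> float:
    weights_per_head = weights_per_head or {}
    losses = []
    for head, logits in logits_per_head.items():
        fn = cross_entropy if loss_per_head[head] == "ce" else focal_loss
        losses.append(fn(logits, targets[head], weights=weights_per_head.get(head)))
    # Simple (unweighted) average over heads.
    return sum(losses) / len(losses)


print(multi_head_loss(
    {"icd10": [2.0, 0.5, -1.0], "mortality": [0.3, -0.3]},
    {"icd10": 0, "mortality": 1},
    {"icd10": "ce", "mortality": "focal"},
    {"mortality": [1.0, 5.0]},  # rare "dead" class weighted up
))
```

Since each head contributes its own term, a head with a very large loss (e.g. a 1000-class head early in training) can dominate the average; per-head class weights and the choice of focal loss are the levers the parameters below expose for this.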

Examples

To create a document classifier component:

import edsnlp, edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.doc_classifier(
        label_attr=["icd10", "mortality"],
        labels={
            "icd10": "data/path_to_label_list_icd10.pkl",
            "mortality": "data/path_to_label_list_mortality.pkl",
        },
        num_classes={
            "icd10": 1000,
            "mortality": 2,
        },
        class_weights={
            "icd10": "data/path_to_class_weights_icd10.pkl",
            "mortality": "data/path_to_class_weights_mortality.pkl",
        },
        embedding=eds.doc_pooler(
            pooling_mode="attention",
            embedding=eds.transformer(
                model="almanach/camembertav2-base",
                window=256,
                stride=128,
            ),
        ),
        hidden_size=1024,
        activation_mode="relu",
        dropout_rate={
            "icd10": 0.05,
            "mortality": 0.2,
        },
        layer_norm=True,
        loss="ce",
    ),
    name="doc_classifier",
)

After training, predictions are stored in the Doc object:

doc = nlp("Patient was admitted with pneumonia and discharged alive.")
print(doc._.icd10, doc._.mortality)
# e.g. J18 alive

Parameters

PARAMETER DESCRIPTION
nlp

The spaCy/edsnlp pipeline the component belongs to.

TYPE: Optional[PipelineProtocol]

name

Component name.

TYPE: str DEFAULT: "doc_classifier"

embedding

Embedding component (e.g. transformer + pooling). Must expose an output_size attribute.

TYPE: WordEmbeddingComponent or WordContextualizerComponent

label_attr

List of head names. Each head corresponds to a document-level attribute (e.g. ["icd10", "mortality"]).

TYPE: List[str]

num_classes

Number of classes for each head. If not provided, inferred from labels.

TYPE: dict of str -> int DEFAULT: None

label2id

Per-head mapping from label string to integer ID.

TYPE: dict of str -> dict[str,int] DEFAULT: None

id2label

Per-head reverse mapping from integer ID to label string.

TYPE: dict of str -> dict[int,str] DEFAULT: None
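When num_classes, label2id and id2label are not given, they can be derived from each head's label set. The following is a hypothetical helper (`build_mappings` is not part of edsnlp) that assumes IDs are assigned in sorted label order; the component may order labels differently.

```python
from typing import Dict, Iterable, Tuple


def build_mappings(labels_per_head: Dict[str, Iterable[str]]) -> Tuple[dict, dict, dict]:
    label2id: Dict[str, Dict[str, int]] = {}
    id2label: Dict[str, Dict[int, str]] = {}
    num_classes: Dict[str, int] = {}
    for head, labels in labels_per_head.items():
        # Sorting gives a deterministic ID order (an assumption, see above).
        ordered = sorted(set(labels))
        label2id[head] = {label: i for i, label in enumerate(ordered)}
        id2label[head] = {i: label for label, i in label2id[head].items()}
        num_classes[head] = len(ordered)
    return label2id, id2label, num_classes


label2id, id2label, num_classes = build_mappings(
    {"mortality": ["alive", "dead", "alive"], "icd10": ["J18", "I10"]}
)
print(label2id["mortality"])  # {'alive': 0, 'dead': 1}
print(id2label["icd10"])      # {0: 'I10', 1: 'J18'}
print(num_classes)            # {'mortality': 2, 'icd10': 2}
```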

loss

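Loss type: "ce" (cross-entropy) or "focal" (focal loss), either a single value shared by all heads or a dict with one value per head.

TYPE: {"ce", "focal"} or dict[str, {"ce", "focal"}] DEFAULT: "ce"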
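For reference, the two options can be written for one head h with logits z, gold class y and class weight w_y (taken as 1 when no class weights are given). The focal term is the standard focal-loss formulation with focusing parameter γ; the value of γ used by the component and how it reduces over a batch are not documented here. The last line is the simple average over the H heads described in the architecture steps.

```latex
p^{(h)} = \operatorname{softmax}\big(z^{(h)}\big), \qquad
\mathcal{L}^{(h)}_{\mathrm{ce}} = -\, w^{(h)}_{y} \log p^{(h)}_{y}, \qquad
\mathcal{L}^{(h)}_{\mathrm{focal}} = -\, w^{(h)}_{y} \big(1 - p^{(h)}_{y}\big)^{\gamma} \log p^{(h)}_{y}

\mathcal{L} = \frac{1}{H} \sum_{h=1}^{H} \mathcal{L}^{(h)}
```

With γ = 0 the focal loss reduces to cross-entropy; larger γ down-weights examples the head already classifies confidently.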

labels

Paths to pickle files containing label sets for each head.

TYPE: dict of str -> str (path) DEFAULT: None
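The expected content of these pickle files is not specified beyond "label sets". The sketch below assumes one pickled, sorted list of label strings per head, in line with the path_to_label_list_*.pkl naming in the example; `write_label_files` and the label_list_<head>.pkl file names are hypothetical.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Dict, Iterable


def write_label_files(labels_per_head: Dict[str, Iterable[str]], out_dir: str) -> Dict[str, str]:
    """Pickle one sorted label list per head and return {head: path}."""
    paths = {}
    for head, labels in labels_per_head.items():
        path = Path(out_dir) / f"label_list_{head}.pkl"
        with path.open("wb") as f:
            pickle.dump(sorted(set(labels)), f)
        paths[head] = str(path)
    return paths


with tempfile.TemporaryDirectory() as tmp:
    paths = write_label_files({"mortality": ["alive", "dead", "alive"]}, tmp)
    with open(paths["mortality"], "rb") as f:
        print(pickle.load(f))  # ['alive', 'dead']
```

The returned dict has the same shape as the labels argument, so it could be passed directly as labels=... if the component accepts this file format.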

class_weights

Per-head paths to pickle files containing class-frequency dicts, which are converted into class weights for the loss.

TYPE: dict of str -> str (path) DEFAULT: None
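The exact frequency-to-weight conversion is not documented here. As an illustration only, the sketch pickles a {label: count} dict for one head and applies one common inverse-frequency ("balanced") scheme, n_total / (n_classes * n_c), so that rare classes receive larger weights; the component may use a different formula. Both helper names are hypothetical.

```python
import pickle
import tempfile
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable


def write_class_frequencies(gold_labels: Iterable[str], path: str) -> Dict[str, int]:
    """Count how often each gold label occurs and pickle the {label: count} dict."""
    freqs = dict(Counter(gold_labels))
    with open(path, "wb") as f:
        pickle.dump(freqs, f)
    return freqs


def balanced_weights(freqs: Dict[str, int]) -> Dict[str, float]:
    """Inverse-frequency weights n_total / (n_classes * n_c) (illustrative scheme)."""
    total = sum(freqs.values())
    return {label: total / (len(freqs) * count) for label, count in freqs.items()}


with tempfile.TemporaryDirectory() as tmp:
    freqs = write_class_frequencies(["alive"] * 9 + ["dead"], str(Path(tmp) / "freqs.pkl"))
print(freqs)                    # {'alive': 9, 'dead': 1}
print(balanced_weights(freqs))  # alive ≈ 0.556, dead = 5.0
```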

hidden_size

Size of the optional hidden layer placed before each head's classifier, shared or per-head. If None, no hidden layer is used.

TYPE: int or dict[str, int] DEFAULT: None

activation_mode

Activation function for hidden layers, shared or per-head.

TYPE: {"relu", "gelu", "silu"} or dict[str, {"relu", "gelu", "silu"}] DEFAULT: "relu"

dropout_rate

Dropout rate after activation, shared or per-head.

TYPE: float or dict[str, float] DEFAULT: 0.0

layer_norm

Whether to apply layer normalization in hidden layers, shared or per-head.

TYPE: bool or dict[str, bool] DEFAULT: False
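Several parameters above (loss, hidden_size, activation_mode, dropout_rate, layer_norm) accept either one value shared by all heads or a dict keyed by head name. The hypothetical helper below shows how such a setting can be expanded into one value per head; falling back to a default for heads missing from the dict is an assumption, not documented behavior.

```python
from typing import Any, Dict, List


def resolve_per_head(value: Any, heads: List[str], default: Any = None) -> Dict[str, Any]:
    """Expand a shared setting to every head, or look up a per-head dict."""
    if isinstance(value, dict):
        # Heads missing from the dict fall back to `default` (an assumption).
        return {head: value.get(head, default) for head in heads}
    return {head: value for head in heads}


heads = ["icd10", "mortality"]
print(resolve_per_head(1024, heads))  # {'icd10': 1024, 'mortality': 1024}
print(resolve_per_head({"icd10": 0.05, "mortality": 0.2}, heads, default=0.0))
print(resolve_per_head({"icd10": True}, heads, default=False))  # {'icd10': True, 'mortality': False}
```

Applied to the example configuration, hidden_size=1024 gives both heads a 1024-unit hidden layer, while dropout_rate resolves to 0.05 for icd10 and 0.2 for mortality.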