Advanced Text Classification: Multi-label, Zero-shot, and Few-shot
Text classification is one of the most common NLP tasks, but in practice it goes well beyond a simple "positive or negative". A news article can simultaneously cover politics, economics, and international affairs. A support ticket can belong to multiple categories at once. And a document can be assigned to a category the model never saw during training.
In this article we tackle text classification in all its complexity: multi-class (multiple mutually exclusive classes), multi-label (multiple simultaneous labels), hierarchical classification, and zero-shot classification (without training examples for the target classes). We include implementations with Focal Loss for imbalanced datasets, SetFit for few-shot scenarios, and production-ready inference pipelines.
This is the sixth article in the Modern NLP: from BERT to LLMs series. It assumes familiarity with BERT and the HuggingFace ecosystem.
What You Will Learn
- Difference between binary, multi-class, and multi-label — when to use what
- Fine-tuning BERT for multi-class classification with composite metrics
- Multi-label classification with sigmoid, BCEWithLogitsLoss, and Focal Loss
- Per-label threshold tuning in multi-label with F1 optimization
- Zero-shot classification with NLI models and custom hypothesis templates
- Few-shot classification with SetFit: 8-64 examples per class
- Hierarchical classification with flat vs. top-down approaches
- Handling imbalanced datasets: class weighting, Focal Loss, oversampling
- Multi-label metrics: hamming loss, micro/macro F1, subset accuracy
- Production-ready classification pipeline with caching and batch inference
1. Text Classification Taxonomy
Choosing the correct classification type is the first fundamental step. The choice influences the loss function, evaluation metrics, and model architecture.
Types of Text Classification: Decision Guide
| Type | Description | Practical Example | Output Layer | Loss Function | Main Metric |
|---|---|---|---|---|---|
| Binary | 2 mutually exclusive classes | Spam vs Ham, Positive vs Negative | Sigmoid(1) | BCEWithLogitsLoss | F1, AUC-ROC |
| Multi-class | N classes, pick one | News category, language detection | Softmax(N) | CrossEntropyLoss | Accuracy, F1 macro |
| Multi-label | N classes, multiple active at once | Article tags, multiple emotions | Sigmoid(N) | BCEWithLogitsLoss | Hamming Loss, Micro F1 |
| Hierarchical | Classes organized in hierarchy | Product category (Electronics > TV > OLED) | Varies | Hierarchical loss | F1 per level |
| Zero-shot | Classes never seen in training | Routing on arbitrary topics | NLI entailment scores | NLI training loss | F1, accuracy per class |
| Few-shot (SetFit) | Few examples per class (8-64) | Domain-specific classification | Logistic head | Contrastive + CE | Accuracy, F1 |
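The output-layer column in the table is the crux of the taxonomy. A minimal PyTorch sketch contrasts the two activations on the same logits: softmax (multi-class) couples the classes, sigmoid (multi-label) scores each one independently.

```python
import torch

# Toy logits for one example over 4 candidate classes
logits = torch.tensor([[2.0, 0.5, -1.0, 0.1]])

# Multi-class: softmax couples the classes -- probabilities sum to 1, pick one
softmax_probs = torch.softmax(logits, dim=-1)
print(softmax_probs)  # probabilities sum to 1

# Multi-label: sigmoid scores each class independently -- several can pass 0.5
sigmoid_probs = torch.sigmoid(logits)
print((sigmoid_probs >= 0.5).sum().item())  # 3 classes pass the 0.5 threshold
```

Note how the same logits yield one winner under softmax but three active labels under sigmoid with a 0.5 threshold.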
2. Multi-class Classification with BERT
In multi-class classification, exactly one class is correct for each example. We use softmax as the output-layer activation and CrossEntropyLoss as the loss function. BERT reaches strong, near-state-of-the-art results on benchmarks like AG News (~95%), Yelp Review Full (~70%), and Yahoo Answers (~77%).
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TrainingArguments,
Trainer,
EarlyStoppingCallback
)
from datasets import load_dataset
import evaluate
import numpy as np
import torch
# AG News: classify news into 4 balanced categories
# World, Sports, Business, Sci/Tech
dataset = load_dataset("ag_news")
print("AG News dataset:", dataset)
# train: 120,000 examples (30,000 per class — balanced!)
# test: 7,600 examples
LABELS = ["World", "Sports", "Business", "Sci/Tech"]
num_labels = len(LABELS)
MODEL = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
def tokenize(examples):
return tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=128
)
tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
tokenized.set_format("torch")
# Multi-class model: softmax over 4 classes
model = AutoModelForSequenceClassification.from_pretrained(
MODEL,
num_labels=num_labels,
id2label={i: l for i, l in enumerate(LABELS)},
label2id={l: i for i, l in enumerate(LABELS)}
)
# Composite metrics
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
probs = torch.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).numpy()
avg_confidence = probs.max(axis=1).mean()
return {
"accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
"f1_macro": f1.compute(predictions=preds, references=labels, average="macro")["f1"],
"f1_weighted": f1.compute(predictions=preds, references=labels, average="weighted")["f1"],
"avg_confidence": float(avg_confidence)
}
args = TrainingArguments(
output_dir="./results/bert-agnews",
num_train_epochs=3,
per_device_train_batch_size=32,
per_device_eval_batch_size=64,
learning_rate=2e-5,
warmup_ratio=0.1,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1_macro",
fp16=True,
report_to="none",
seed=42
)
trainer = Trainer(
model=model, args=args,
train_dataset=tokenized["train"],
eval_dataset=tokenized["test"],
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)
trainer.train()
# Expected results on AG News test:
# Accuracy: ~94-95%, F1 macro: ~94-95%
# Production-ready inference
from transformers import pipeline
clf_pipeline = pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
device=0 if torch.cuda.is_available() else -1
)
texts = ["European Central Bank raises interest rates by 0.5pp", "Juventus wins Champions League"]
results = clf_pipeline(texts)
for text, pred in zip(texts, results):
print(f" '{text[:50]}' -> {pred['label']} ({pred['score']:.3f})")
3. Multi-label Classification
In multi-label, each example can have zero, one, or multiple active labels. The fundamental change from multi-class is in the output layer (sigmoid instead of softmax) and the loss function (BCEWithLogitsLoss instead of CrossEntropyLoss). Each class is treated as an independent binary problem.
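The "independent binary problems" framing can be checked directly: `BCEWithLogitsLoss` applies a sigmoid internally and averages binary cross-entropy over every (example, label) cell. A minimal sketch with hand-picked logits:

```python
import torch
from torch import nn
import torch.nn.functional as F

# Two examples, five labels (e.g. economics, politics, sports, technology, environment)
logits = torch.tensor([[1.2, -0.3, -2.0, 0.8, -1.5],
                       [0.1,  2.1, -0.5, -0.9,  1.3]])
targets = torch.tensor([[1., 0., 0., 1., 0.],
                        [0., 1., 0., 0., 1.]])

# BCEWithLogitsLoss = sigmoid + binary cross-entropy, averaged over all cells
loss = nn.BCEWithLogitsLoss()(logits, targets)

# Equivalent two-step computation: each label is scored independently
manual = F.binary_cross_entropy(torch.sigmoid(logits), targets)
print(torch.allclose(loss, manual, atol=1e-6))  # True
```

Because each cell is scored on its own, predicting zero labels or all five labels for one example is perfectly legal, unlike with softmax.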
3.1 Multi-label Dataset Preparation
from datasets import Dataset
from transformers import AutoTokenizer
import numpy as np
import pandas as pd
# Multi-label dataset: news articles with multiple simultaneous tags
data = {
"text": [
"The ECB raised interest rates to combat European inflation",
"Juventus beat Milan 2-1 in an exciting Champions League match",
"Apple unveiled the new iPhone with an advanced AI chip and camera",
"The government approved a new tax law amid political controversy",
"Climate crisis hits the economies of developing nations",
"Tesla invests $2B in solar panels, cutting CO2 emissions by 40%",
"The EU Commission proposes new AI regulation rules",
],
# Labels: [economics, politics, sports, technology, environment]
"labels": [
[1, 0, 0, 0, 0], # economics only
[0, 0, 1, 0, 0], # sports only
[0, 0, 0, 1, 0], # technology only
[1, 1, 0, 0, 0], # economics + politics
[1, 0, 0, 0, 1], # economics + environment
[1, 0, 0, 1, 1], # economics + technology + environment
[1, 1, 0, 1, 0], # economics + politics + technology
]
}
LABELS = ["economics", "politics", "sports", "technology", "environment"]
NUM_LABELS = len(LABELS)
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
def tokenize_multilabel(examples):
encoding = tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=128
)
# Convert labels to float (required by BCEWithLogitsLoss)
encoding["labels"] = [
[float(l) for l in label_list]
for label_list in examples["labels"]
]
return encoding
dataset = Dataset.from_dict(data)
tokenized = dataset.map(tokenize_multilabel, batched=True, remove_columns=["text"])
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])  # include "labels" so the Trainer receives targets
# Label distribution analysis
labels_df = pd.DataFrame(data["labels"], columns=LABELS)
print("\nLabel distribution:")
for col in LABELS:
count = int(labels_df[col].sum())
pct = 100 * count / len(labels_df)
print(f" {col:<15s}: {count}/{len(labels_df)} examples ({pct:.0f}%)")
3.2 Multi-label Model with Custom Loss
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch import nn
import torch
import numpy as np
# =========================================
# Standard BCEWithLogitsLoss Trainer
# =========================================
class MultiLabelTrainer(Trainer):
"""Custom Trainer for multi-label with BCEWithLogitsLoss."""
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # **kwargs absorbs num_items_in_batch on newer transformers
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits.float(), labels.float().to(logits.device))
return (loss, outputs) if return_outputs else loss
# =========================================
# Focal Loss for imbalanced multi-label datasets
# =========================================
class FocalLossMultiLabelTrainer(Trainer):
"""
Trainer with Focal Loss for imbalanced multi-label datasets.
Focal Loss down-weights easy examples and focuses on hard ones.
FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
gamma=2 is standard (Lin et al. 2017, RetinaNet paper)
"""
def __init__(self, *args, focal_gamma: float = 2.0, focal_alpha: float = 0.25, **kwargs):
super().__init__(*args, **kwargs)
self.focal_gamma = focal_gamma
self.focal_alpha = focal_alpha
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # **kwargs absorbs num_items_in_batch on newer transformers
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
probs = torch.sigmoid(logits.float())
labels_float = labels.float().to(logits.device)
bce_loss = nn.functional.binary_cross_entropy_with_logits(
logits.float(), labels_float, reduction='none'
)
pt = probs * labels_float + (1 - probs) * (1 - labels_float)
focal_weight = (1 - pt) ** self.focal_gamma
alpha_t = self.focal_alpha * labels_float + (1 - self.focal_alpha) * (1 - labels_float)
focal_loss = (alpha_t * focal_weight * bce_loss).mean()
return (focal_loss, outputs) if return_outputs else focal_loss
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-multilingual-cased",
num_labels=NUM_LABELS,
problem_type="multi_label_classification"
)
def compute_multilabel_metrics(eval_pred):
logits, labels = eval_pred
probs = torch.sigmoid(torch.tensor(logits)).numpy()
predictions = (probs >= 0.5).astype(int)
hamming = np.mean(predictions != labels)
exact_match = np.mean(np.all(predictions == labels, axis=1))
from sklearn.metrics import f1_score
micro_f1 = f1_score(labels, predictions, average='micro', zero_division=0)
macro_f1 = f1_score(labels, predictions, average='macro', zero_division=0)
return {
"hamming_loss": hamming,
"subset_accuracy": exact_match,
"micro_f1": micro_f1,
"macro_f1": macro_f1
}
args = TrainingArguments(
output_dir="./results/bert-multilabel",
num_train_epochs=5,
per_device_train_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True,
report_to="none"
)
# Use FocalLoss for imbalanced datasets, MultiLabelTrainer for balanced
trainer = FocalLossMultiLabelTrainer(
model=model, args=args,
train_dataset=tokenized,
compute_metrics=compute_multilabel_metrics,
focal_gamma=2.0,
focal_alpha=0.25
)
trainer.train()
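Class weighting (mentioned in the learning goals) is a lighter alternative to Focal Loss: `BCEWithLogitsLoss` accepts a per-label `pos_weight` tensor that up-weights positives of rare labels. A sketch using the label counts of the toy dataset above; the negatives/positives ratio is one common heuristic, not the only choice:

```python
import torch
from torch import nn

# Positives per label in the 7-example toy dataset:
# [economics, politics, sports, technology, environment]
label_counts = torch.tensor([5., 2., 1., 3., 2.])
dataset_size = 7.0

# Heuristic: pos_weight = num_negatives / num_positives per label,
# so a missed positive on a rare label costs more
pos_weight = (dataset_size - label_counts) / label_counts
print(pos_weight)  # the rare "sports" label gets the largest weight (6.0)

# Drop this into MultiLabelTrainer.compute_loss in place of the unweighted loss
weighted_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

Unlike Focal Loss, `pos_weight` only rebalances positives versus negatives per label; it does not down-weight easy examples, so the two techniques can also be combined.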
3.3 Per-label Threshold Optimization
The default threshold (0.5) is not always optimal for every label in multi-label classification. You can optimize the threshold for each label separately to maximize F1. This is especially important with imbalanced datasets.
from sklearn.metrics import f1_score
import numpy as np
import torch
def find_optimal_thresholds(logits: np.ndarray, true_labels: np.ndarray,
thresholds=None, label_names=None) -> np.ndarray:
"""Find the optimal threshold per label that maximizes F1 on the validation set."""
if thresholds is None:
thresholds = np.arange(0.05, 0.95, 0.05)
probs = 1 / (1 + np.exp(-logits)) # sigmoid
n_labels = logits.shape[1]
optimal_thresholds = np.zeros(n_labels)
print("Searching optimal threshold per label:")
for label_idx in range(n_labels):
best_f1, best_threshold = 0, 0.5
for threshold in thresholds:
preds = (probs[:, label_idx] >= threshold).astype(int)
f1 = f1_score(true_labels[:, label_idx], preds, zero_division=0)
if f1 > best_f1:
best_f1, best_threshold = f1, threshold
optimal_thresholds[label_idx] = best_threshold
label_name = label_names[label_idx] if label_names else f"label_{label_idx}"
support = int(true_labels[:, label_idx].sum())
print(f" {label_name:<15s}: threshold={best_threshold:.2f}, F1={best_f1:.4f} (n={support})")
return optimal_thresholds
def predict_multilabel(texts: list, model, tokenizer, thresholds: np.ndarray,
label_names: list, batch_size: int = 32) -> list:
"""Multi-label prediction with per-label optimized thresholds."""
model.eval()
all_results = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
        inputs = tokenizer(batch, return_tensors='pt', truncation=True,
                           padding=True, max_length=128).to(model.device)
        with torch.no_grad():
            probs = torch.sigmoid(model(**inputs).logits).cpu().numpy()
for sample_probs in probs:
sample_results = [
{"label": label, "probability": float(prob)}
for prob, threshold, label in zip(sample_probs, thresholds, label_names)
if prob >= threshold
]
all_results.append(sorted(sample_results, key=lambda x: -x["probability"]))
return all_results
4. Zero-shot Classification
Zero-shot classification allows classifying texts into categories the model has never seen during training. It leverages models trained on Natural Language Inference (NLI): given a text and a hypothesis, the model predicts whether the hypothesis is true (entailment), false (contradiction), or uncertain (neutral).
The process: the text is used as "premise", the category as "hypothesis" (e.g., "This text is about economics"). The entailment score indicates how much the text belongs to that category.
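To make the mechanism concrete, here is roughly what the zero-shot pipeline does under the hood, sketched manually. The label index order below is that of `facebook/bart-large-mnli` (contradiction, neutral, entailment); other NLI checkpoints may order labels differently.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(model_name)
nli_model.eval()

premise = "The Federal Reserve raised interest rates by 25 basis points."
scores = {}
for label in ["economics", "sports"]:
    # The candidate label becomes an NLI hypothesis
    hypothesis = f"This example is {label}."
    inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = nli_model(**inputs).logits[0]
    # Like the pipeline, drop "neutral" and softmax over contradiction/entailment
    scores[label] = torch.softmax(logits[[0, 2]], dim=0)[1].item()
print(scores)
```

One NLI forward pass per candidate label is why zero-shot classification gets slow with large label sets.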
from transformers import pipeline
# Best NLI models for zero-shot classification:
# - facebook/bart-large-mnli (English, excellent EN)
# - cross-encoder/nli-deberta-v3-large (most accurate, EN only)
# - MoritzLaurer/mDeBERTa-v3-base-mnli-xnli (multilingual: EN, IT, DE, FR...)
# - joeddav/xlm-roberta-large-xnli (multilingual alternative)
# English zero-shot
classifier_en = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=0  # requires a GPU; use device=-1 (or omit) to run on CPU
)
# Multilingual zero-shot
classifier_multi = pipeline(
"zero-shot-classification",
model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
)
# Single-label classification
text = "The Federal Reserve raised interest rates by 25 basis points in its October meeting."
categories = ["economics", "politics", "sports", "technology", "environment"]
result = classifier_en(text, candidate_labels=categories, multi_label=False)
print("Top categories (EN single-label):")
for label, score in zip(result['labels'][:3], result['scores'][:3]):
print(f" {label}: {score:.3f}")
# Multi-label zero-shot
text_multi = "Tesla invests $2B in solar panels and wind farms, reducing CO2 emissions."
result_multi = classifier_multi(text_multi, candidate_labels=categories, multi_label=True)
print("\nMulti-label zero-shot:")
for label, score in zip(result_multi['labels'], result_multi['scores']):
if score > 0.3:
print(f" {label}: {score:.3f}")
# =========================================
# Custom hypothesis templates
# =========================================
# Default: "This example is {label}."
# Custom template: more specific, often better results
text = "Apple's revenue declined 8% year-over-year due to weak iPhone demand."
templates = {
"default": "This example is {}.",
"news_specific": "This news article is about {}.",
"domain_specific": "This text discusses {} topics.",
}
for template_name, template in templates.items():
result = classifier_en(
text,
candidate_labels=["financial news", "technology news", "sports news"],
hypothesis_template=template
)
print(f"\nTemplate '{template_name}': {result['labels'][0]} ({result['scores'][0]:.3f})")
# Domain-specific template for legal documents
legal_text = "The court ruled in favor of the plaintiff regarding patent infringement."
legal_result = classifier_en(
legal_text,
candidate_labels=["intellectual property", "criminal law", "contract dispute", "employment law"],
hypothesis_template="This legal document concerns {} law."
)
print(f"\nLegal domain: {legal_result['labels'][0]} ({legal_result['scores'][0]:.3f})")
5. Few-shot Classification with SetFit
SetFit (Sentence Transformer Fine-Tuning) allows training accurate classifiers with very few examples per class (8-16 examples). The idea is simple: fine-tune a sentence transformer to recognize similar/dissimilar pairs using the few-shot dataset via contrastive learning, then train a simple logistic head on the resulting embeddings.
SetFit matches or outperforms standard fine-tuning and GPT-3 few-shot prompting on many benchmarks with just 8 examples per class, while using a much smaller model.
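The two-stage recipe can be sketched without the `setfit` library, using `sentence-transformers` plus a scikit-learn head. Note this skips SetFit's contrastive fine-tuning of the encoder, which is where most of the accuracy gain comes from; the encoder name is just a small example model.

```python
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression

texts = ["Fed raises interest rates", "Liverpool wins the London derby",
         "Inflation falls to 2.4%", "Alcaraz reaches the Wimbledon final"]
labels = [0, 1, 0, 1]  # 0 = economics, 1 = sports

# Stage 1 (simplified): embed the few-shot examples. SetFit additionally
# fine-tunes this encoder contrastively on similar/dissimilar text pairs.
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = encoder.encode(texts)

# Stage 2: fit a simple logistic classification head on the embeddings
head = LogisticRegression().fit(embeddings, labels)
print(head.predict(encoder.encode(["GDP grew 1.3% this quarter"])))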
# pip install setfit
from setfit import SetFitModel, Trainer as SetFitTrainer, TrainingArguments as SetFitArgs
from datasets import Dataset
# Few-shot training data: only 8 examples per class
train_data = {
"text": [
# Economics (8 examples)
"Interest rates raised by 0.5% by the Fed",
"US GDP grew 1.3% in the second quarter",
"Wall Street closes up 2.8% after earnings reports",
"Inflation falls to 2.4% due to lower energy costs",
"Amazon announces 5,000 new jobs in tech division",
"Federal deficit exceeds 3% of GDP projection",
"US exports to Asian markets grew by 12%",
"Dollar strengthens against Euro and Yen",
# Sports (8 examples)
"Manchester City wins Premier League title with record points",
"Carlos Alcaraz reaches the Wimbledon final against Djokovic",
"Ferrari wins pole position at Monaco Grand Prix",
"England qualifies for the FIFA World Cup",
"Liverpool signs striker for 80 million euros",
"Chelsea draws 1-1 with Tottenham in London derby",
"Noah Lyles breaks American record in 100 meters",
"USA women's soccer team wins Olympic gold medal",
# Technology (8 examples)
"OpenAI releases new reasoning model with advanced capabilities",
"Apple unveils M4 chip with improved neural processing",
"Google acquires AI startup for $2 billion",
"Tesla autonomous driving reaches Level 4 certification",
"Meta introduces enhanced privacy controls for users",
"Samsung announces 2nm chips for 2025 production",
"Microsoft integrates Copilot AI deeply into Windows",
"5G network coverage reaches 75% of US population",
],
"label": [
0, 0, 0, 0, 0, 0, 0, 0, # economics = 0
1, 1, 1, 1, 1, 1, 1, 1, # sports = 1
2, 2, 2, 2, 2, 2, 2, 2 # technology = 2
]
}
test_data = {
"text": [
"The Fed keeps interest rates unchanged at 5.5%",
"Barcelona beats Liverpool 3-1 in Champions League",
"NVIDIA reaches $2 trillion market capitalization",
],
"label": [0, 1, 2]
}
train_dataset = Dataset.from_dict(train_data)
test_dataset = Dataset.from_dict(test_data)
# Load SetFit model (multilingual)
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
labels=["economics", "sports", "technology"]
)
# Very fast training (few minutes even on CPU)
args = SetFitArgs(
batch_size=16,
    num_epochs=1,        # epochs for the contrastive embedding fine-tuning phase
    num_iterations=20,   # pair-generation iterations (contrastive pairs per example)
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
report_to="none",
)
trainer = SetFitTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
metric="accuracy"
)
trainer.train()
# Inference
texts = [
"The government approved the federal budget for 2025",
"Inter beat Barcelona 2-0 in Champions League final",
"OpenAI presents the new reasoning model o3-pro"
]
predictions = model.predict(texts)
scores = model.predict_proba(texts)
print("\nSetFit predictions:")
for text, pred, prob in zip(texts, predictions, scores):
    # predict() returns the label strings passed via labels= at load time
    print(f"  '{text[:45]}...' -> {pred} ({max(prob):.3f})")
6. Hierarchical Classification
In many real scenarios categories are organized in hierarchies. An article might be classified as "Technology > AI > NLP". There are two main approaches: flat (ignores hierarchy) and hierarchical (exploits the structure for better accuracy).
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from typing import Dict, Tuple
HIERARCHY = {
"Economics": ["Financial Markets", "Macroeconomics", "Trade", "Labor"],
"Politics": ["Domestic Policy", "Foreign Policy", "Elections", "Legislation"],
"Sports": ["Soccer", "Tennis", "Formula 1", "Athletics"],
"Technology": ["AI", "Smartphones", "Cloud", "Cybersecurity"]
}
class HierarchicalClassifier:
"""
Top-down hierarchical classifier.
Step 1: classify into coarse category (Economics, Politics, ...)
Step 2: classify into subcategory (Financial Markets, Macroeconomics, ...)
"""
def __init__(self, coarse_model_path: str, fine_models: Dict[str, str]):
self.tokenizer = AutoTokenizer.from_pretrained(coarse_model_path)
self.coarse_model = AutoModelForSequenceClassification.from_pretrained(coarse_model_path)
self.coarse_model.eval()
self.fine_models = {}
for cat, path in fine_models.items():
m = AutoModelForSequenceClassification.from_pretrained(path)
m.eval()
self.fine_models[cat] = m
def predict(self, text: str) -> dict:
"""Hierarchical prediction with confidence scores."""
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=256)
# Step 1: coarse classification
with torch.no_grad():
coarse_probs = torch.softmax(self.coarse_model(**inputs).logits, dim=-1)[0]
coarse_id = coarse_probs.argmax().item()
coarse_label = self.coarse_model.config.id2label[coarse_id]
coarse_score = float(coarse_probs[coarse_id])
# Step 2: fine classification (if available for this category)
fine_label, fine_score = None, None
if coarse_label in self.fine_models:
with torch.no_grad():
fine_probs = torch.softmax(self.fine_models[coarse_label](**inputs).logits, dim=-1)[0]
fine_id = fine_probs.argmax().item()
fine_label = self.fine_models[coarse_label].config.id2label[fine_id]
fine_score = float(fine_probs[fine_id])
return {
"coarse": coarse_label,
"fine": fine_label,
"coarse_confidence": coarse_score,
"fine_confidence": fine_score,
"full_path": f"{coarse_label} > {fine_label}" if fine_label else coarse_label
}
print("HierarchicalClassifier defined!")
7. Complete Multi-label Evaluation Metrics
from sklearn.metrics import (
f1_score, precision_score, recall_score,
hamming_loss, accuracy_score, average_precision_score
)
import numpy as np
def multilabel_evaluation_report(y_true: np.ndarray, y_pred: np.ndarray,
y_proba: np.ndarray, label_names: list) -> dict:
"""Complete multi-label classification report."""
print("=" * 65)
print("MULTI-LABEL CLASSIFICATION REPORT")
print("=" * 65)
hl = hamming_loss(y_true, y_pred)
sa = accuracy_score(y_true, y_pred) # subset accuracy
micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
print(f"\n{'Hamming Loss':<25s}: {hl:.4f} (lower is better)")
print(f"{'Subset Accuracy':<25s}: {sa:.4f} (all labels must match)")
print(f"{'Micro F1':<25s}: {micro_f1:.4f} (label-frequency weighted)")
print(f"{'Macro F1':<25s}: {macro_f1:.4f} (unweighted average)")
print(f"{'Weighted F1':<25s}: {f1_score(y_true, y_pred, average='weighted', zero_division=0):.4f}")
if y_proba is not None:
try:
            macro_ap = average_precision_score(y_true, y_proba, average='macro')
            print(f"{'Macro AP (AUC-PR)':<25s}: {macro_ap:.4f}")
except Exception:
pass
print("\nPer-label metrics:")
header = f"{'Label':<18s} {'Precision':>10s} {'Recall':>10s} {'F1':>8s} {'Support':>10s}"
print(header)
print("-" * 65)
for i, label in enumerate(label_names):
prec = precision_score(y_true[:, i], y_pred[:, i], zero_division=0)
rec = recall_score(y_true[:, i], y_pred[:, i], zero_division=0)
f1 = f1_score(y_true[:, i], y_pred[:, i], zero_division=0)
support = int(y_true[:, i].sum())
print(f"{label:<18s} {prec:>10.4f} {rec:>10.4f} {f1:>8.4f} {support:>10d}")
return {"hamming_loss": hl, "subset_accuracy": sa, "micro_f1": micro_f1, "macro_f1": macro_f1}
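The headline metrics in the report are easy to verify by hand on a tiny example, which also shows why subset accuracy is much stricter than Hamming loss:

```python
import numpy as np
from sklearn.metrics import hamming_loss, accuracy_score

y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0]])
y_pred = np.array([[1, 0, 0],   # one label cell wrong
                   [0, 1, 0],   # exact match
                   [1, 0, 0]])  # one label cell wrong

# Hamming loss: fraction of wrong (sample, label) cells -> 2 of 9
print(hamming_loss(y_true, y_pred))
# Subset accuracy: fraction of rows where ALL labels match -> 1 of 3
print(accuracy_score(y_true, y_pred))
```

Two single-cell mistakes cost only 2/9 in Hamming loss but wipe out two thirds of subset accuracy, which is why subset accuracy alone can look brutally low on models that are actually quite good.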
8. Production-ready Classification Pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np
from typing import Union, List, Dict
import time
class ProductionClassifier:
"""
Production-ready classifier with:
- Batch inference for efficiency
- Multi-class and multi-label support
- Latency monitoring
- Confidence tracking
"""
def __init__(
self,
model_path: str,
task: str = "multi_class", # "multi_class" or "multi_label"
thresholds: np.ndarray = None,
max_length: int = 128,
batch_size: int = 32
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
self.task = task
self.label_names = list(self.model.config.id2label.values())
self.thresholds = thresholds if thresholds is not None else np.full(len(self.label_names), 0.5)
self.max_length = max_length
self.batch_size = batch_size
self._latencies = []
@torch.no_grad()
def predict(self, texts: Union[str, List[str]]) -> List[Dict]:
"""Prediction with latency monitoring."""
if isinstance(texts, str):
texts = [texts]
start = time.perf_counter()
all_results = []
for i in range(0, len(texts), self.batch_size):
batch = texts[i:i+self.batch_size]
inputs = self.tokenizer(
batch, return_tensors='pt', truncation=True,
padding=True, max_length=self.max_length
).to(self.device)
outputs = self.model(**inputs)
logits = outputs.logits.cpu().numpy()
if self.task == "multi_class":
                # numerically stable softmax: subtract the row max before exponentiating
                shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
                probs = shifted / shifted.sum(axis=1, keepdims=True)
for p in probs:
pred_id = p.argmax()
all_results.append({
"label": self.label_names[pred_id],
"score": float(p[pred_id]),
"all_scores": {name: float(s) for name, s in zip(self.label_names, p)}
})
else: # multi_label
probs = 1 / (1 + np.exp(-logits))
for p in probs:
labels = [
{"label": name, "score": float(score)}
for name, score, thr in zip(self.label_names, p, self.thresholds)
if score >= thr
]
all_results.append({"labels": sorted(labels, key=lambda x: -x["score"])})
self._latencies.append((time.perf_counter() - start) * 1000)
return all_results
def get_stats(self) -> Dict:
"""Inference latency statistics."""
if not self._latencies:
return {}
return {
"avg_latency_ms": np.mean(self._latencies),
"p99_latency_ms": np.percentile(self._latencies, 99),
"total_predictions": len(self._latencies)
}
print("ProductionClassifier ready for deployment!")
9. Classification with Generative Models (LLM Prompting)
With the rise of LLMs, text classification can now be performed simply through prompting, without any training. This approach is particularly useful for rapid prototyping and for new or rare categories that are hard to annotate.
from transformers import pipeline
import json
# =========================================
# Classification with LLM prompting
# =========================================
def classify_with_llm(text: str, categories: list, model_pipeline) -> dict:
"""
Zero-shot classification with an instruction-following LLM.
No fine-tuning needed: relies on natural language understanding.
Works well with models like Mistral-7B-Instruct, Llama-3, Phi-3.
"""
categories_str = ", ".join(categories)
prompt = f"""Classify the following text into ONE of these categories: {categories_str}.
Text: "{text}"
Respond ONLY with the category name, no explanations.
Category:"""
response = model_pipeline(
prompt,
max_new_tokens=20,
temperature=0.0, # deterministic
do_sample=False
)[0]['generated_text']
# Extract category from response
answer = response[len(prompt):].strip().split('\n')[0].strip()
# Validate that response is a valid category
for cat in categories:
if cat.lower() in answer.lower():
return {"label": cat, "method": "llm", "raw_answer": answer}
return {"label": "unknown", "method": "llm", "raw_answer": answer}
def classify_with_fewshot(text: str, categories: list, examples: list, model_pipeline) -> dict:
"""
Few-shot classification: provides examples in the prompt to guide the model.
Even 3-5 examples can dramatically improve LLM classification accuracy.
"""
examples_str = ""
for ex in examples[:3]: # max 3 examples to avoid context overflow
examples_str += f'Text: "{ex["text"]}"\nCategory: {ex["label"]}\n\n'
prompt = f"""Classify texts into: {", ".join(categories)}.
Examples:
{examples_str}Text: "{text}"
Category:"""
response = model_pipeline(
prompt, max_new_tokens=15, temperature=0.0, do_sample=False
)[0]['generated_text']
answer = response[len(prompt):].strip().split('\n')[0].strip()
return {"label": answer, "method": "few-shot-llm"}
# =========================================
# Comparison: zero-shot NLI vs LLM prompting vs BERT fine-tuned
# =========================================
comparison_table = [
{"method": "BERT fine-tuned", "F1": "0.95+", "speed": "fast", "data": "1000+ examples", "cost": "low"},
{"method": "SetFit (few-shot)", "F1": "0.85+", "speed": "fast", "data": "8-64 examples", "cost": "low"},
{"method": "Zero-shot NLI (BART)", "F1": "0.70+", "speed": "medium", "data": "zero examples", "cost": "low"},
{"method": "LLM prompting (7B local)", "F1": "0.75+", "speed": "slow", "data": "zero examples", "cost": "medium"},
{"method": "LLM few-shot (7B local)", "F1": "0.82+", "speed": "slow", "data": "3-10 examples", "cost": "medium"},
{"method": "GPT-4 prompting (API)", "F1": "0.88+", "speed": "very slow", "data": "zero examples", "cost": "high"},
]
print("=== Classification Methods Comparison ===")
print(f"{'Method':<35s} {'F1':<10s} {'Speed':<15s} {'Data Required':<20s} {'Cost'}")
print("-" * 90)
for row in comparison_table:
print(f"{row['method']:<35s} {row['F1']:<10s} {row['speed']:<15s} {row['data']:<20s} {row['cost']}")
print("\nRecommendation: start with zero-shot NLI to validate the task,")
print("then fine-tune BERT if you have labeled data, or use SetFit with a few annotated examples.")
Anti-Pattern: Using Accuracy as the Only Metric
With imbalanced datasets (e.g., 95% negative, 5% positive), a model that always predicts "negative" gets 95% accuracy but is completely useless. Always use F1, precision, and recall for binary and multi-class classification. For multi-label, use hamming loss and micro/macro F1. Never overlook the class distribution in your dataset.
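The trap is easy to demonstrate in a few lines:

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Imbalanced dataset: 95 negatives, 5 positives
y_true = np.array([0] * 95 + [1] * 5)
# Degenerate "model" that always predicts the majority class
y_pred = np.zeros(100, dtype=int)

print(accuracy_score(y_true, y_pred))             # 0.95 -- looks impressive
print(f1_score(y_true, y_pred, zero_division=0))  # 0.0 -- exposes the useless model
```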
Approach Selection Guide
| Scenario | Recommended Approach | Setup Time |
|---|---|---|
| Fixed categories, lots of data (>5K) | Standard BERT fine-tuning | Hours |
| Fixed categories, few data (<100) | SetFit (few-shot) | Minutes |
| Variable or unknown categories | Zero-shot NLI + custom templates | Instant |
| Multi-label, balanced dataset | BERT + BCEWithLogitsLoss + threshold tuning | Hours |
| Multi-label, imbalanced dataset | Focal Loss + per-label threshold tuning | Hours |
| Category hierarchy | Top-down HierarchicalClassifier | Days |
| Rapid prototyping | Zero-shot pipeline | Seconds |
10. Real-world Benchmarks and Model Selection
Choosing the right model for text classification depends on task type, data volume, latency requirements, and available hardware. Here is a practical benchmark comparing the most popular approaches across standard datasets.
Text Classification Benchmarks (2024-2025)
| Task | Dataset | Model | Accuracy / F1 | Notes |
|---|---|---|---|---|
| Binary Sentiment | SST-2 (EN) | DistilBERT fine-tuned | Acc 92.7% | 6x faster than BERT-base |
| Binary Sentiment | SST-2 (EN) | RoBERTa-large fine-tuned | Acc 96.4% | State of the art |
| Multi-class (4 classes) | AG News | BERT-base fine-tuned | Acc 94.8% | Standard benchmark |
| Multi-label | Reuters-21578 | RoBERTa + BCELoss | Micro-F1 89.2% | 90 categories |
| Zero-shot | Yahoo Answers | BART-large-MNLI | Acc 70.3% | No training data |
| Few-shot (8 examples) | SST-2 (EN) | SetFit (MiniLM) | Acc 88.1% | 8 labeled examples only |
| Italian Sentiment | SENTIPOLC 2016 | dbmdz BERT fine-tuned | F1 91.3% | Best Italian model |
# Quick benchmark script: compare multiple classifiers on your dataset
from transformers import pipeline
from sklearn.metrics import f1_score, accuracy_score
import time
import numpy as np
def benchmark_classifiers(test_texts: list, test_labels: list, label_names: list):
    """
    Benchmark zero-shot NLI on a small test set to gauge whether the task
    is feasible without training (extend with SetFit or a fine-tuned model
    for a full comparison).
    """
results = {}
# --- Zero-shot NLI (no training needed) ---
zs = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
start = time.perf_counter()
zs_preds = []
for text in test_texts:
out = zs(text, candidate_labels=label_names)
zs_preds.append(label_names.index(out['labels'][0]))
zs_time = (time.perf_counter() - start) * 1000
zs_f1 = f1_score(test_labels, zs_preds, average='macro', zero_division=0)
zs_acc = accuracy_score(test_labels, zs_preds)
results["zero-shot-NLI"] = {
"macro_f1": round(zs_f1, 4),
"accuracy": round(zs_acc, 4),
"total_time_ms": round(zs_time, 1),
"requires_training": False
}
print("=== Classifier Benchmark ===")
print(f"{'Model':<25s} {'Macro F1':>10s} {'Accuracy':>10s} {'Time (ms)':>12s} {'Training?':>12s}")
print("-" * 75)
for name, res in results.items():
print(f"{name:<25s} {res['macro_f1']:>10.4f} {res['accuracy']:>10.4f} "
f"{res['total_time_ms']:>12.1f} {str(res['requires_training']):>12s}")
return results
# Usage:
# test_texts = ["Apple reports record profits", "Italy wins World Cup", "New AI model beats GPT-4"]
# test_labels = [0, 1, 0] # 0=technology, 1=sports
# benchmark_classifiers(test_texts, test_labels, ["technology", "sports"])
print("Benchmark script ready!")
Conclusions and Next Steps
Modern text classification goes well beyond binary classification. Zero-shot, few-shot, and multi-label are real-world scenarios requiring specific approaches. With the tools covered in this article — from SetFit for few-shot to Focal Loss for imbalanced datasets, from zero-shot NLI to hierarchical classifiers — you have the foundation to tackle any text classification scenario in production.
Continue the Modern NLP Series
- Previous: Named Entity Recognition — entity extraction with BERT
- Next: HuggingFace Transformers: Complete Guide — ecosystem and Trainer API
- Article 8: Local LLM Fine-tuning — LoRA and QLoRA on consumer GPU
- Article 9: Semantic Similarity — sentence embeddings and FAISS for search
- Article 10: NLP Monitoring in Production — drift detection and retraining
- Related series: AI Engineering/RAG — zero-shot classification as routing in RAG
- Related series: Deep Learning Advanced — advanced classification architectures