🎯 Multiclass: Handle 3+ Classes with Advanced Strategies
This example shows two powerful strategies for multiclass threshold optimization: One-vs-Rest (OvR) and Coordinate Ascent. See which works best for your data.
[1]:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split
from optimal_cutoffs import optimize_thresholds
print("π― MULTICLASS THRESHOLD OPTIMIZATION")
print("=" * 50)
🎯 MULTICLASS THRESHOLD OPTIMIZATION
==================================================
SCENARIO: Document Classification
News articles: 3 categories (Politics, Sports, Tech)
Slightly imbalanced dataset for realistic comparison
Model outputs: probability scores per class
[2]:
# Generate realistic multiclass dataset
X, y = make_classification(
    n_samples=2000,
    n_features=20,
    n_classes=3,
    n_informative=15,
    n_redundant=3,
    n_clusters_per_class=1,
    weights=[0.4, 0.35, 0.25],  # Slightly imbalanced
    flip_y=0.01,
    random_state=42,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
# Train classifier
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
y_prob = model.predict_proba(X_test)
# Dataset info
class_names = ['Politics', 'Sports', 'Tech']
class_counts = np.bincount(y_test)
print(f"π Test dataset: {len(y_test)} samples")
for i, (name, count) in enumerate(zip(class_names, class_counts)):
percentage = count / len(y_test) * 100
print(f" β’ Class {i} ({name}): {count} samples ({percentage:.1f}%)")
print()
Test dataset: 600 samples
• Class 0 (Politics): 240 samples (40.0%)
• Class 1 (Sports): 210 samples (35.0%)
• Class 2 (Tech): 150 samples (25.0%)
❌ BEFORE: Standard argmax approach
[3]:
# Standard approach: predict class with highest probability
y_pred_argmax = np.argmax(y_prob, axis=1)
# Calculate metrics
f1_macro_argmax = f1_score(y_test, y_pred_argmax, average='macro')
f1_micro_argmax = f1_score(y_test, y_pred_argmax, average='micro')
print("β BEFORE: Standard argmax approach")
print(f" Macro F1: {f1_macro_argmax:.3f}")
print(f" Micro F1: {f1_micro_argmax:.3f}")
print(" Strategy: Predict class with highest probability (no thresholds)")
print()
print("Confusion Matrix (argmax):")
cm_argmax = confusion_matrix(y_test, y_pred_argmax)
print(cm_argmax)
print()
❌ BEFORE: Standard argmax approach
Macro F1: 0.959
Micro F1: 0.958
Strategy: Predict class with highest probability (no thresholds)
Confusion Matrix (argmax):
[[231   7   2]
 [  7 203   0]
 [  8   1 141]]
✅ STRATEGY 1: One-vs-Rest (OvR) Independent Thresholds
[4]:
# Strategy 1: One-vs-Rest with independent thresholds per class
# Each class gets its own optimal threshold, optimized independently
print("β
STRATEGY 1: One-vs-Rest (OvR) Independent Thresholds")
print("-" * 60)
# Find optimal thresholds using OvR independent strategy
ovr_result = optimize_thresholds(
    y_test, y_prob,
    metric='f1',
    method='auto',    # Auto-select appropriate method for multiclass OvR
    average='macro'   # Macro averaging for F1
)
print(f"Optimal thresholds: {ovr_result.thresholds}")
for i, (name, threshold) in enumerate(zip(class_names, ovr_result.thresholds)):
    print(f"  • Class {i} ({name}): {threshold:.3f}")
# Make predictions using OvR strategy
y_pred_ovr = ovr_result.predict(y_prob)
# Calculate metrics
f1_macro_ovr = f1_score(y_test, y_pred_ovr, average='macro')
f1_micro_ovr = f1_score(y_test, y_pred_ovr, average='micro')
print(f"\nPerformance:")
print(f" Macro F1: {f1_macro_ovr:.3f} (vs {f1_macro_argmax:.3f} argmax)")
print(f" Micro F1: {f1_micro_ovr:.3f} (vs {f1_micro_argmax:.3f} argmax)")
macro_improvement_ovr = ((f1_macro_ovr - f1_macro_argmax) / f1_macro_argmax) * 100
print(f" π Macro F1 improvement: {macro_improvement_ovr:+.1f}%")
print()
print("Confusion Matrix (OvR Independent):")
cm_ovr = confusion_matrix(y_test, y_pred_ovr)
print(cm_ovr)
print()
✅ STRATEGY 1: One-vs-Rest (OvR) Independent Thresholds
------------------------------------------------------------
Optimal thresholds: [-0.03 0.03 -0.13]
• Class 0 (Politics): -0.030
• Class 1 (Sports): 0.030
• Class 2 (Tech): -0.130
Performance:
Macro F1: 0.964 (vs 0.959 argmax)
Micro F1: 0.963 (vs 0.958 argmax)
Macro F1 improvement: +0.5%
Confusion Matrix (OvR Independent):
[[233   3   4]
 [  8 201   1]
 [  5   1 144]]
✅ STRATEGY 2: Coordinate Ascent (Single-Label Consistent)
[5]:
# Strategy 2: Coordinate Ascent for single-label consistency
# Optimizes thresholds while ensuring exactly one prediction per sample
print("β
STRATEGY 2: Coordinate Ascent (Single-Label Consistent)")
print("-" * 60)
# Find optimal thresholds using coordinate ascent
coord_result = optimize_thresholds(
    y_test, y_prob,
    metric='f1',
    method='coord_ascent',  # Coordinate ascent optimization
    average='macro'
)
print(f"Optimal thresholds: {coord_result.thresholds}")
for i, (name, threshold) in enumerate(zip(class_names, coord_result.thresholds)):
    print(f"  • Class {i} ({name}): {threshold:.3f}")
# Make predictions using coordinate ascent strategy
y_pred_coord = coord_result.predict(y_prob)
# Calculate metrics
f1_macro_coord = f1_score(y_test, y_pred_coord, average='macro')
f1_micro_coord = f1_score(y_test, y_pred_coord, average='micro')
print(f"\nPerformance:")
print(f" Macro F1: {f1_macro_coord:.3f} (vs {f1_macro_argmax:.3f} argmax)")
print(f" Micro F1: {f1_micro_coord:.3f} (vs {f1_micro_argmax:.3f} argmax)")
macro_improvement_coord = ((f1_macro_coord - f1_macro_argmax) / f1_macro_argmax) * 100
print(f" π Macro F1 improvement: {macro_improvement_coord:+.1f}%")
print()
print("Confusion Matrix (Coordinate Ascent):")
cm_coord = confusion_matrix(y_test, y_pred_coord)
print(cm_coord)
print()
✅ STRATEGY 2: Coordinate Ascent (Single-Label Consistent)
------------------------------------------------------------
Optimal thresholds: [-0.03 0.03 -0.13]
• Class 0 (Politics): -0.030
• Class 1 (Sports): 0.030
• Class 2 (Tech): -0.130
Performance:
Macro F1: 0.964 (vs 0.959 argmax)
Micro F1: 0.963 (vs 0.958 argmax)
Macro F1 improvement: +0.5%
Confusion Matrix (Coordinate Ascent):
[[233   3   4]
 [  8 201   1]
 [  5   1 144]]
COMPARISON: Which strategy works best?
[6]:
print("π STRATEGY COMPARISON")
print("=" * 30)
strategies = [
    ('Argmax (baseline)', f1_macro_argmax, f1_micro_argmax, 'Standard approach'),
    ('OvR Independent', f1_macro_ovr, f1_micro_ovr, 'Can predict multiple classes'),
    ('Coordinate Ascent', f1_macro_coord, f1_micro_coord, 'Single-label consistent')
]
print(f"{'Strategy':<20} {'Macro F1':<10} {'Micro F1':<10} {'Notes':<30}")
print("-" * 75)
best_macro = max(strategies, key=lambda x: x[1])
best_micro = max(strategies, key=lambda x: x[2])
for name, macro_f1, micro_f1, notes in strategies:
    macro_star = " 🏆" if (name, macro_f1, micro_f1, notes) == best_macro else ""
    micro_star = " 🏆" if (name, macro_f1, micro_f1, notes) == best_micro else ""
    print(f"{name:<20} {macro_f1:<10.3f}{macro_star:<3} {micro_f1:<10.3f}{micro_star:<3} {notes:<30}")
print()
# Show improvements
ovr_improvement = ((f1_macro_ovr - f1_macro_argmax) / f1_macro_argmax) * 100
coord_improvement = ((f1_macro_coord - f1_macro_argmax) / f1_macro_argmax) * 100
print("π Improvement over argmax baseline:")
print(f" β’ OvR Independent: {ovr_improvement:+.1f}% macro F1")
print(f" β’ Coordinate Ascent: {coord_improvement:+.1f}% macro F1")
print()
STRATEGY COMPARISON
==============================
Strategy             Macro F1   Micro F1   Notes
---------------------------------------------------------------------------
Argmax (baseline)    0.959         0.958         Standard approach
OvR Independent      0.964      🏆  0.963      🏆  Can predict multiple classes
Coordinate Ascent    0.964         0.963         Single-label consistent
Improvement over argmax baseline:
• OvR Independent: +0.5% macro F1
• Coordinate Ascent: +0.5% macro F1
PREDICTION BEHAVIOR ANALYSIS
[7]:
# Analyze how often each strategy predicts multiple/no classes
print("π PREDICTION BEHAVIOR ANALYSIS")
print("=" * 40)
# For OvR Independent: check if multiple classes predicted
# (This can happen when multiple classes exceed their thresholds)
ovr_binary_predictions = y_prob >= ovr_result.thresholds[None, :]
ovr_predictions_per_sample = ovr_binary_predictions.sum(axis=1)
multiple_predictions = (ovr_predictions_per_sample > 1).sum()
no_predictions = (ovr_predictions_per_sample == 0).sum()
single_predictions = (ovr_predictions_per_sample == 1).sum()
print(f"OvR Independent Strategy:")
print(f" β’ Samples with single prediction: {single_predictions} ({single_predictions/len(y_test)*100:.1f}%)")
print(f" β’ Samples with multiple predictions: {multiple_predictions} ({multiple_predictions/len(y_test)*100:.1f}%)")
print(f" β’ Samples with no predictions: {no_predictions} ({no_predictions/len(y_test)*100:.1f}%)")
print()
# Coordinate ascent always predicts exactly one class
print(f"Coordinate Ascent Strategy:")
print(f" β’ Always predicts exactly one class (single-label consistent)")
print(f" β’ Uses margin rule: argmax(probability - threshold)")
print()
# Show some examples where strategies differ
different_predictions = (y_pred_ovr != y_pred_coord)
n_different = different_predictions.sum()
print(f"π Strategy Agreement:")
print(f" β’ Samples where strategies agree: {len(y_test) - n_different} ({(len(y_test) - n_different)/len(y_test)*100:.1f}%)")
print(f" β’ Samples where strategies differ: {n_different} ({n_different/len(y_test)*100:.1f}%)")
print()
if n_different > 0:
    print("Example differences (first 5):")
    diff_indices = np.where(different_predictions)[0][:5]
    for idx in diff_indices:
        print(f"  Sample {idx}: Probs={np.round(y_prob[idx], 2)}, OvR={y_pred_ovr[idx]}, Coord={y_pred_coord[idx]}, True={y_test[idx]}")
PREDICTION BEHAVIOR ANALYSIS
========================================
OvR Independent Strategy:
• Samples with single prediction: 0 (0.0%)
• Samples with multiple predictions: 600 (100.0%)
• Samples with no predictions: 0 (0.0%)
Coordinate Ascent Strategy:
• Always predicts exactly one class (single-label consistent)
• Uses margin rule: argmax(probability - threshold)
Strategy Agreement:
• Samples where strategies agree: 600 (100.0%)
• Samples where strategies differ: 0 (0.0%)
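The margin rule printed above can be written out in a couple of lines. Below is a minimal sketch of the two decision rules using the thresholds fitted earlier; the names thr_ovr, thr_coord, ovr_label_matrix, and margin_labels are local to this sketch, and the coordinate-ascent rule is reconstructed from the "argmax(probability - threshold)" description rather than taken from the library internals.

# Sketch only: the two decision rules side by side (not library code)
thr_ovr = np.asarray(ovr_result.thresholds)
thr_coord = np.asarray(coord_result.thresholds)

# OvR independent: accept every class whose probability clears its own
# threshold, so a sample can end up with zero, one, or several labels.
ovr_label_matrix = (y_prob >= thr_ovr[None, :]).astype(int)  # shape (n_samples, n_classes)

# Margin rule (single-label consistent): pick the class with the largest
# probability-minus-threshold margin, so exactly one label per sample.
margin_labels = np.argmax(y_prob - thr_coord[None, :], axis=1)

With identical thresholds for every class, argmax(probability - threshold) reduces to plain argmax, so the baseline at the top of this notebook is simply the special case of equal thresholds.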
Visualize threshold effects
[8]:
import matplotlib.pyplot as plt
# Plot probability distributions and thresholds
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (ax, class_name) in enumerate(zip(axes, class_names)):
    # Plot probability distribution for this class
    class_probs = y_prob[:, i]
    # Separate by true class
    true_class_probs = class_probs[y_test == i]
    other_class_probs = class_probs[y_test != i]
    ax.hist(other_class_probs, bins=30, alpha=0.6, label='Other classes', color='lightcoral')
    ax.hist(true_class_probs, bins=30, alpha=0.8, label=f'True {class_name}', color='lightblue')
    # Add threshold lines
    ax.axvline(ovr_result.thresholds[i], color='red', linestyle='--',
               label=f'OvR Threshold ({ovr_result.thresholds[i]:.3f})')
    ax.axvline(coord_result.thresholds[i], color='green', linestyle='--',
               label=f'Coord Threshold ({coord_result.thresholds[i]:.3f})')
    ax.set_xlabel(f'Probability for {class_name}')
    ax.set_ylabel('Count')
    ax.set_title(f'Class {i}: {class_name}')
    ax.legend()
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("π The histograms show how optimal thresholds separate true class from others")
The histograms show how the optimal thresholds separate the true class from the others
What's Next?
04_interactive_demo.ipynb: Deep dive into mathematical foundations
API Documentation: Explore more advanced multiclass options
💡 Multiclass Strategy Guide
When to use One-vs-Rest (OvR) Independent:
Multi-label problems: Where samples can belong to multiple classes
Imbalanced classes: Each class optimized independently
Different costs per class: Each class can have different error costs
When to use Coordinate Ascent:
Single-label problems: Where each sample belongs to exactly one class
Coupled optimization: When class decisions should be consistent
Margin-based decisions: When you want argmax-style behavior with thresholds
🎯 Advanced Tips
Try both strategies: Performance depends on your specific data
Cross-validation: Use CV to validate threshold choices (see the sketch after this list)
Micro vs Macro: Choose averaging based on your problem priorities
Class imbalance: OvR often works better for highly imbalanced datasets
Computational cost: Coordinate ascent is more expensive but can give better coupled optimization
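For the cross-validation tip above, here is a minimal sketch that re-tunes thresholds inside each fold and scores them on the held-out fold. It assumes the optimize_thresholds API used throughout this notebook (a result object with .thresholds and .predict) and reuses X_train, y_train, and the imports from earlier cells; the fold handling and variable names are illustrative, not prescriptive.

from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_scores = []
for fit_idx, val_idx in cv.split(X_train, y_train):
    # Refit the model on the fold's training portion
    fold_model = RandomForestClassifier(n_estimators=100, random_state=42)
    fold_model.fit(X_train[fit_idx], y_train[fit_idx])

    # Tune thresholds on the fold's training predictions ...
    fit_probs = fold_model.predict_proba(X_train[fit_idx])
    fold_result = optimize_thresholds(
        y_train[fit_idx], fit_probs,
        metric='f1', method='coord_ascent', average='macro'
    )

    # ... and score them on the held-out fold
    val_probs = fold_model.predict_proba(X_train[val_idx])
    val_pred = fold_result.predict(val_probs)
    fold_scores.append(f1_score(y_train[val_idx], val_pred, average='macro'))

print(f"CV macro F1 with tuned thresholds: {np.mean(fold_scores):.3f} ± {np.std(fold_scores):.3f}")

In practice you may want a further split inside each fold so thresholds are tuned on probabilities the model has not already fit; the sketch keeps things simple to show the mechanics.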