PyTorch

Accelerate your PyTorch model development with ZenML

Integrate PyTorch, a powerful deep learning framework, with ZenML to streamline model development and experimentation. By combining ZenML's model-agnostic pipelines with PyTorch's flexibility, you can iterate on models quickly, track experiments, and move models into production.

Features with ZenML

  • Seamless PyTorch Integration:
    Incorporate PyTorch models and training logic into ZenML pipelines for a unified workflow.
  • Reproducible Experiments:
    Track and version PyTorch data objects and models with ZenML to ensure reproducibility and make collaboration easier.
  • Effortless Handling of PyTorch Data Artifacts and Models:
    ZenML knows how to serialize PyTorch artifacts such as DataLoader and nn.Module objects, so you can pass them between steps that run in different environments.
  • Streamlined Deployment:
    Move PyTorch models from experimentation to production using ZenML's deployment integrations.
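Passing a model between steps that may run in different environments comes down to serializing it on one side and restoring it on the other. The snippet below is a minimal illustration of that idea in plain PyTorch, not ZenML's own materializer code (the integration handles this for you): it saves a module's state_dict to an in-memory buffer and loads it into a fresh instance. The TinyNet class and the in-memory buffer are hypothetical choices made for the example.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model used only for this illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def round_trip(model: nn.Module) -> nn.Module:
    """Serialize a model's parameters and restore them into a new instance."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)  # write parameters to bytes
    buffer.seek(0)  # rewind before reading
    restored = TinyNet()
    restored.load_state_dict(torch.load(buffer, weights_only=True))
    restored.eval()
    return restored


torch.manual_seed(0)
original = TinyNet().eval()
restored_model = round_trip(original)
x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.equal(original(x), restored_model(x)))  # True
```

Saving the state_dict rather than the whole object keeps the stored artifact to plain tensors, at the cost of needing the model class available when loading.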

Main Features

  • Flexible and expressive deep learning framework.
  • Extensive ecosystem of pre-trained models and extensions.
  • Optimizers, loss functions, and other predefined helper classes that work out of the box.
  • Strong community support and comprehensive documentation.
  • Interoperability with popular data science tools and libraries.
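To illustrate the built-in helpers, here is a small sketch that fits the line y = 2x + 1 using only nn.Linear, nn.MSELoss, and torch.optim.SGD, with no custom training utilities. The synthetic data, learning rate, and step count are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic data on the line y = 2x + 1 (illustrative only)
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 2 * x + 1

model = nn.Linear(1, 1)  # learnable weight and bias
loss_fn = nn.MSELoss()  # built-in mean squared error loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # built-in optimizer

with torch.no_grad():
    initial_loss = loss_fn(model(x), y).item()

# Full-batch gradient descent
for _ in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_loss = loss_fn(model(x), y).item()

# The learned parameters should end up close to 2 and 1
print(f"weight={model.weight.item():.3f}, bias={model.bias.item():.3f}")
```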

How to use ZenML with PyTorch

import torch
from torch import nn
from torch.utils.data import DataLoader
from zenml import pipeline, step

# Requires the integration: zenml integration install pytorch
# Run on a GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NeuralNetwork (an nn.Module subclass) and the importer_mnist and
# evaluator steps are assumed to be defined elsewhere in your project.


@step(enable_cache=False)
def trainer(train_dataloader: DataLoader) -> nn.Module:
    """Trains a model for one epoch on the train dataloader."""
    model = NeuralNetwork().to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    size = len(train_dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(DEVICE), y.to(DEVICE)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log progress every 100 batches
        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")
    # The PyTorch integration serializes the returned module as an artifact
    return model


@pipeline()
def fashion_mnist_pipeline():
    """Link all the steps and artifacts together."""
    train_dataloader, test_dataloader = importer_mnist()
    model = trainer(train_dataloader)
    evaluator(test_dataloader=test_dataloader, model=model)


if __name__ == "__main__":
    fashion_mnist_pipeline()
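The pipeline above calls an evaluator step whose body is not shown. As a hedged sketch of what such a step might compute, the hypothetical evaluate function below measures classification accuracy over a DataLoader in plain PyTorch; in a real pipeline you would decorate a function like it with @step. The synthetic TensorDataset stands in for the Fashion-MNIST test set and, like the hand-set model weights, is an assumption for the example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(test_dataloader: DataLoader, model: nn.Module, device: str = "cpu") -> float:
    """Hypothetical evaluator: return the fraction of correctly classified samples."""
    model.eval()  # switch layers such as dropout to inference behavior
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(dim=1) == y).sum().item()
    return correct / len(test_dataloader.dataset)


# Synthetic 2-class data: the label is 1 when the first feature is positive
torch.manual_seed(0)
features = torch.randn(64, 2)
labels = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, labels), batch_size=16)

# A linear layer with logits (-x0, x0) implements exactly that rule
model = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[-1.0, 0.0], [1.0, 0.0]]))

print(evaluate(loader, model))  # 1.0
```

Dividing by len(test_dataloader.dataset) rather than the number of batches keeps the metric correct when the last batch is smaller than batch_size.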

Connect Your ML Pipelines to a World of Tools

Expand your ML pipelines with more than 50 ZenML Integrations

  • Amazon S3
  • Apache Airflow
  • Argilla
  • AutoGen
  • AWS
  • AWS Strands
  • Azure Blob Storage
  • Azure Container Registry
  • AzureML Pipelines
  • BentoML
  • Comet