Accelerate and simplify model training with Lightning AI Studio and ZenML
Integrating Lightning AI with ZenML enables data scientists and ML engineers to streamline the model development process. By using the Lightning AI integration, you can run your ZenML pipelines on Lightning AI’s infrastructure, leveraging its managed environments and scalable compute. This integration simplifies code, improves reproducibility, and allows seamless scaling of model training across different hardware.
First, set up a stack that includes a Lightning AI orchestrator. Run the following commands, replacing the placeholder values with your own.
zenml orchestrator register lightning_orchestrator \
    --flavor=lightning \
    --user_id=<YOUR_LIGHTNING_USER_ID> \
    --api_key=<YOUR_LIGHTNING_API_KEY> \
    --username=<YOUR_LIGHTNING_USERNAME> \
    --teamspace=<YOUR_LIGHTNING_TEAMSPACE> \
    --organization=<YOUR_LIGHTNING_ORGANIZATION>
# Register and activate a stack with the new orchestrator
zenml stack register lightning_stack -o lightning_orchestrator ... --set
You can also define settings inside your pipeline code and pass them to the settings parameter of your pipeline. See our code docs for all the values you can set.
from zenml import pipeline, step
from zenml.integrations.lightning.flavors.lightning_orchestrator_flavor import (
    LightningOrchestratorSettings,
)

lightning_settings = LightningOrchestratorSettings(
    main_studio_name="my_studio",  # change this to your studio name if you already have one
    machine_type="cpu",
    async_mode=True,
    custom_commands=["pip install -r requirements.txt"],
)

@step
def load_data() -> dict:
    """Simulates loading of training data and labels."""
    training_data = [[1, 2], [3, 4], [5, 6]]
    labels = [0, 1, 0]
    return {"features": training_data, "labels": labels}

@step
def train_model(data: dict) -> None:
    """
    A mock 'training' process that also demonstrates using the input data.
    In a real-world scenario, this would be replaced with actual model fitting logic.
    """
    total_features = sum(map(sum, data["features"]))
    total_labels = sum(data["labels"])
    print(
        f"Trained model using {len(data['features'])} data points. "
        f"Feature sum is {total_features}, label sum is {total_labels}"
    )

@pipeline(settings={"orchestrator": lightning_settings})
def simple_ml_pipeline():
    """Define a pipeline that connects the steps."""
    dataset = load_data()
    train_model(dataset)

if __name__ == "__main__":
    run = simple_ml_pipeline()
    # You can now use the `run` object to see steps, outputs, etc.
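Before submitting the pipeline to Lightning AI, it can be useful to sanity-check the step logic locally as plain Python. The sketch below mirrors the bodies of load_data and train_model without the ZenML decorators, so it runs without a stack or orchestrator. The helper names load_data_plain and summarize_training are hypothetical and introduced only for illustration; they are not part of ZenML or of the pipeline above.

```python
def load_data_plain() -> dict:
    """Return the same toy dataset as the `load_data` step, without @step."""
    training_data = [[1, 2], [3, 4], [5, 6]]
    labels = [0, 1, 0]
    return {"features": training_data, "labels": labels}


def summarize_training(data: dict) -> str:
    """Build the message that the `train_model` step prints."""
    # Sum every value across all feature rows, then sum the labels.
    total_features = sum(map(sum, data["features"]))
    total_labels = sum(data["labels"])
    return (
        f"Trained model using {len(data['features'])} data points. "
        f"Feature sum is {total_features}, label sum is {total_labels}"
    )


if __name__ == "__main__":
    print(summarize_training(load_data_plain()))
```

With the toy dataset above, this prints "Trained model using 3 data points. Feature sum is 21, label sum is 1", which is the message to expect from the train_model step once the pipeline runs on Lightning AI.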