py
import fal
from pathlib import Path
DATA_DIR = Path("/data/mnist")
@fal.function(
"virtualenv",
requirements=["torch>=2.0.0", "torchvision"],
machine_type="M",
)
def train_fashion_model():
import torch
from torchvision import datasets
already_present = DATA_DIR.exists()
if already_present:
print("Test data is already downloaded, skipping download!")
test_data = datasets.FashionMNIST(
root=DATA_DIR,
train=False,
download=not already_present,
)
...
if __name__ == "__main__":
train_fashion_model()
当您第一次调用此函数时,您会注意到 Torch 会下载测试数据集。但是,后续调用(即使是调用的 keep_alive
未涵盖的调用)也会跳过下载并直接进入您的逻辑。
对于 HF 相关库,fal 确保所有下载的模型都保留,以避免在运行 ML 推理工作负载时重新下载。无需为transformers
或diffusers
自定义输出路径。
上次更新于 2024 年 6 月 21 日