Wallaroo SDK Essentials Guide: Model Uploads and Registrations: PyTorch
Model Naming Requirements
Model names map onto Kubernetes objects and must be DNS compliant. Model names may contain only lowercase ASCII alphanumeric characters or dashes (-). Periods (.) and underscores (_) are not allowed.
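Before calling the SDK, a name can be checked locally against the rule above. The following is a minimal sketch. `is_valid_model_name` is a hypothetical helper, not part of the Wallaroo SDK, and it encodes only the character rule stated here. Kubernetes DNS naming can impose further constraints (such as length limits) that this sketch does not check.

```python
import re

# Hypothetical helper: encodes only the character rule stated above
# (lowercase ASCII letters, digits, and dashes; no "." or "_").
_MODEL_NAME_PATTERN = re.compile(r"[a-z0-9-]+")


def is_valid_model_name(name: str) -> bool:
    """Return True if `name` uses only lowercase ASCII alphanumerics and dashes."""
    return _MODEL_NAME_PATTERN.fullmatch(name) is not None


print(is_valid_model_name("pt-single-io-model"))  # True
print(is_valid_model_name("pt_single_io_model"))  # False: underscores
print(is_valid_model_name("pt.model"))            # False: period
```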
Parameter | Description |
---|---|
Web Site | https://pytorch.org/ |
Supported Libraries | |
Framework | Framework.PYTORCH aka pytorch |
Supported File Types | pt or pth in TorchScript format |
IMPORTANT NOTE
The PyTorch model must be in TorchScript format. Scripting (i.e. torch.jit.script()) is always recommended over tracing (i.e. torch.jit.trace()). From the PyTorch documentation: “Scripting preserves dynamic control flow and is valid for inputs of different sizes.”
For more details, see TorchScript-based ONNX Exporter: Tracing vs Scripting.
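To illustrate the scripting recommendation, the sketch below converts a small PyTorch module to TorchScript with torch.jit.script() and saves it as a .pt file. The model class `SingleIOModel` and the file name `single_io_model.pt` are hypothetical placeholders; a real model would be trained before export. The data-dependent `if` in `forward` is the kind of control flow that scripting preserves, while tracing would record only the branch taken for the example input.

```python
import torch
import torch.nn as nn


class SingleIOModel(nn.Module):
    """Hypothetical example model: 10 float inputs -> 1 float output."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: kept by scripting, flattened away by tracing.
        if x.sum() > 0:
            return self.linear(x)
        return -self.linear(x)


model = SingleIOModel().eval()
scripted = torch.jit.script(model)               # convert to TorchScript by scripting
torch.jit.save(scripted, "single_io_model.pt")   # hypothetical output path

# Reload the TorchScript file and run it on one row of 10 float32 values.
loaded = torch.jit.load("single_io_model.pt")
out = loaded(torch.ones(1, 10, dtype=torch.float32))
print(tuple(out.shape))  # (1, 1)
```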
During the model upload process, Wallaroo optimizes the model by converting it to the Wallaroo Native Runtime when possible; otherwise, the model runs directly in the Wallaroo Containerized Runtime. See Model Deploy for details on how to configure pipeline resources based on the model’s runtime.
- IMPORTANT CONFIGURATION NOTE: For PyTorch input schemas, floats must be pyarrow.float32() for the PyTorch model to be converted to the Wallaroo Native Runtime during the upload process.
PyTorch Input and Output Schemas
PyTorch input and output schemas have additional requirements depending on whether the PyTorch model is single input/output or multiple input/output. This refers to the number of columns:
- Single Input/Output: Has one input and one output column.
- Multiple Input/Output: Has more than one input or more than one output column.
The column names used by the model itself can be anything. For example:
- Model Input Fields:
  - length
  - width
  - intensity
  - etc.
However, when creating the input and output schemas for uploading a PyTorch model to Wallaroo, the field names must follow the requirements below. For example, for multi-column PyTorch models, the input fields would be:
- Data Schema Input Fields:
  - input_1
  - input_2
  - input_3
  - input_...
For single input/output PyTorch models, the field names must be input and output. For example, if the input field is a list of floats of size 10 and the output field is a list of floats of size 1, the input and output schemas are:
```python
import pyarrow as pa

input_schema = pa.schema([
    pa.field('input', pa.list_(pa.float32(), list_size=10))
])
output_schema = pa.schema([
    pa.field('output', pa.list_(pa.float32(), list_size=1))
])
```
For multi input/output PyTorch models, the data schema fields must be named input_1, input_2, ... and output_1, output_2, etc. These must be in the same order in which the PyTorch model was trained to accept them.
For example, a multi input/output PyTorch model that takes the following inputs and outputs:
- Inputs
  - input_1: List of floats of length 10.
  - input_2: List of floats of length 5.
- Outputs
  - output_1: List of floats of length 3.
  - output_2: List of floats of length 2.
The following input and output schemas would be used.
```python
import pyarrow as pa

input_schema = pa.schema([
    pa.field('input_1', pa.list_(pa.float32(), list_size=10)),
    pa.field('input_2', pa.list_(pa.float32(), list_size=5))
])
output_schema = pa.schema([
    pa.field('output_1', pa.list_(pa.float32(), list_size=3)),
    pa.field('output_2', pa.list_(pa.float32(), list_size=2))
])
```
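For reference, a PyTorch model with these shapes might look like the sketch below. `MultiIOModel` and the file name `multi_io_model.pt` are hypothetical. The sketch assumes the ordering requirement means input_1 corresponds to the first `forward` argument and output_1 to the first returned tensor; the layers are untrained and only demonstrate the shapes.

```python
from typing import Tuple

import torch
import torch.nn as nn


class MultiIOModel(nn.Module):
    """Hypothetical model: input_1 (10 floats), input_2 (5 floats) -> output_1 (3), output_2 (2)."""

    def __init__(self) -> None:
        super().__init__()
        self.head_1 = nn.Linear(10 + 5, 3)
        self.head_2 = nn.Linear(10 + 5, 2)

    def forward(self, input_1: torch.Tensor,
                input_2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Argument order corresponds to input_1, input_2 in the input schema.
        features = torch.cat([input_1, input_2], dim=1)
        # Returned tuple order corresponds to output_1, output_2 in the output schema.
        return self.head_1(features), self.head_2(features)


scripted = torch.jit.script(MultiIOModel().eval())
torch.jit.save(scripted, "multi_io_model.pt")  # hypothetical output path

out_1, out_2 = scripted(torch.zeros(1, 10), torch.zeros(1, 5))
print(tuple(out_1.shape), tuple(out_2.shape))  # (1, 3) (1, 2)
```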
Uploading PyTorch Models
PyTorch models are uploaded to Wallaroo through the Wallaroo Client upload_model method.
Upload PyTorch Model Parameters
The following parameters are required for PyTorch models. Note that while some of these fields are optional for the upload_model method in general, they are required to properly upload a PyTorch model to Wallaroo.
Parameter | Type | Description |
---|---|---|
name | string (Required) | The name of the model. Model names are unique per workspace. Models that are uploaded with the same name are assigned as a new version of the model. |
path | string (Required) | The path to the model file being uploaded. |
framework | string (Required) | Set as Framework.PYTORCH. |
input_schema | pyarrow.lib.Schema (Required) | The input schema in Apache Arrow schema format. Note that float values must be pyarrow.float32() for the PyTorch model to be converted to a Wallaroo Native Runtime during model upload. |
output_schema | pyarrow.lib.Schema (Required) | The output schema in Apache Arrow schema format. Note that float values must be pyarrow.float32() for the PyTorch model to be converted to a Wallaroo Native Runtime during model upload. |
convert_wait | bool (Optional) (Default: True) | |
arch | wallaroo.engine_config.Architecture | The architecture the model is deployed to. If the model is intended for deployment to an ARM architecture, it must be specified during this step. Values include: X86 (Default): x86-based architectures. ARM: ARM-based architectures. |
Once the upload process starts, the model is containerized by the Wallaroo instance. This process may take up to 10 minutes depending on the size and complexity of the model.
Upload PyTorch Model Return
upload_model returns a wallaroo.model_version.ModelVersion object with the following fields.
Field | Type | Description |
---|---|---|
name | String | The name of the model. |
version | String | The model version as a unique UUID. |
file_name | String | The file name of the model as stored in Wallaroo. |
SHA | String | The hash value of the model file. |
Status | String | The status of the model. |
image_path | String | The image used to deploy the model in the Wallaroo engine. |
last_update_time | DateTime | When the model was last updated. |
Upload PyTorch Model Example
The following example uploads a PyTorch model to a Wallaroo instance.
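```python
import pyarrow as pa
import wallaroo
from wallaroo.framework import Framework

# Connect to the Wallaroo instance
wl = wallaroo.Client()

input_schema = pa.schema(
    [
        pa.field('input', pa.list_(pa.float32(), list_size=10))
    ]
)
output_schema = pa.schema(
    [
        pa.field('output', pa.list_(pa.float32(), list_size=1))
    ]
)
model = wl.upload_model('pt-single-io-model',
                        "./models/model-auto-conversion_pytorch_single_io_model.pt",
                        framework=Framework.PYTORCH,
                        input_schema=input_schema,
                        output_schema=output_schema
                        )
# display() is available in Jupyter/IPython notebooks
display(model)
```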
```
Waiting for model loading - this will take up to 10.0min.
Model is pending loading to a native runtime..
Ready
```