Source code for slashml.model_deployment

import json
import requests
import time
import tarfile
import truss
from enum import Enum
from slashml.utils import generateURL, baseUrl, generateHeaders, formatResponse, getTaskStatus
import typing as t


import os


[docs]class ModelDeployment: _base_url = baseUrl("model-deployment", "v1") _headers = None def __init__(self, api_key: str = None): if api_key==None: raise Exception("API Key is required for model deployment") self._headers = generateHeaders(api_key)
[docs] def create_tar_gz(self, *, folder_path, tar_gz_filename): with tarfile.open(tar_gz_filename, "w:gz") as tar: tar.add(folder_path, arcname=os.path.basename(folder_path))
[docs] def deploy(self, *, model_name:str, model: str, requirements:t.Optional[list] = None): """Submit job""" requirements_file_path = None if requirements: with open('requirements.txt', 'w') as f: for item in requirements: f.write("%s\n" % item) requirements_file_path = 'requirements.txt' truss.create(model, 'my_model', requirements_file=requirements_file_path) self.create_tar_gz(folder_path='my_model', tar_gz_filename='my_model.tar.gz') url = generateURL(self._base_url, "models") files = [("model_file", ("my_model.tar.gz", open('my_model.tar.gz', "rb"), "application/octet-stream"))] payload = { "model_name": model_name, } response = requests.post(url, headers=self._headers, data=payload, files=files) # remove requirements.txt if requirements_file_path: os.remove(requirements_file_path) return formatResponse(response)
[docs] def status(self, *, model_version_id: str): """Check job status""" url = generateURL(self._base_url, "models", model_version_id, "status") response = requests.get(url, headers=self._headers) return formatResponse(response)
[docs] def predict(self, model_version_id: str, model_input:str): """Check job status""" payload = json.dumps({ "model_input": model_input }) url = generateURL(self._base_url, "models", model_version_id, "predict") self._headers['Content-Type'] = 'application/json' response = requests.post(url, headers=self._headers, data=payload) return formatResponse(response)