import logging
from typing import Optional

import cv2
import numpy as np
from tritonclient.grpc import (
    InferenceServerClient,
    InferInput,
    InferRequestedOutput,
)

LOGGER = logging.getLogger("run_inference_on_triton")
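

# Thin synchronous wrapper around Triton's gRPC client: verifies server and
# model readiness once at construction time and exposes a blocking Run() call
# for a single model.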
class SyncGRPCTritonRunner:
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
        self,
        server_url: str,
        model_name: str,
        model_version: str,
        *,
        verbose=False,
        resp_wait_s: Optional[float] = None,
    ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = (
            self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
        )

        self._client = InferenceServerClient(self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(f"Could not communicate to Triton Server: {error}")
        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!"
        )

        model_config = self._client.get_model_config(
            self._model_name, self._model_version
        )
        model_metadata = self._client.get_model_metadata(
            self._model_name, self._model_version
        )
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        # Cache the input/output metadata so Run() can map positional inputs
        # to named tensors and request every declared output.
        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [InferRequestedOutput(name) for name in self._outputs]

    def Run(self, inputs):
        """Run synchronous inference on the model.

        Args:
            inputs: list of numpy arrays, in the same order as
                self._input_names (one array per model input).

        Returns:
            dict mapping output name -> numpy array.
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            # Inputs are matched positionally to the model's declared input
            # names; this client assumes every input tensor is UINT8.
            infer_input = InferInput(self._input_names[idx], data.shape, "UINT8")
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)
        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t,
        )
        return {name: results.as_numpy(name) for name in self._output_names}

    def _verify_triton_state(self, triton_client):
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        elif not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        elif not triton_client.is_model_ready(self._model_name, self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None
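

# Demo: run OCR on a single image against a local Triton server. The model
# name "pp_ocr", version "1", and gRPC port 8001 are deployment-specific;
# adjust them to match your own serving setup.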
if __name__ == "__main__":
    model_name = "pp_ocr"
    model_version = "1"
    url = "localhost:8001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)
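
    # The model expects a batched UINT8 image tensor, so stack the single HWC
    # image read by OpenCV along a new leading batch axis (batch of one).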
    im = cv2.imread("12.jpg")
    if im is None:
        raise FileNotFoundError("Could not read the sample image '12.jpg'")
    im = np.array([im])  # shape (1, H, W, C)
    for _ in range(1):  # single run; increase the range to repeat for timing
        result = runner.Run([im])
        # Per-image outputs: recognized text strings, their confidence
        # scores, and the detected text-box coordinates.
        batch_texts = result["rec_texts"]
        batch_scores = result["rec_scores"]
        batch_bboxes = result["det_bboxes"]
        for i_batch in range(len(batch_texts)):
            texts = batch_texts[i_batch]
            scores = batch_scores[i_batch]
            bboxes = batch_bboxes[i_batch]
            for i_box in range(len(texts)):
                print(
                    "text=",
                    texts[i_box].decode("utf-8"),
                    " score=",
                    scores[i_box],
                    " bbox=",
                    bboxes[i_box],
                )