You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
import logging
|
|
import numpy as np
|
|
import time
|
|
from typing import Optional
|
|
import cv2
|
|
import json
|
|
|
|
from tritonclient import utils as client_utils
|
|
from tritonclient.grpc import (
|
|
InferenceServerClient,
|
|
InferInput,
|
|
InferRequestedOutput,
|
|
service_pb2_grpc,
|
|
service_pb2,
|
|
)
|
|
|
|
# Logger shared by everything in this script; looked up by a fixed name so
# its level/handlers can be configured via logging under that name.
LOGGER = logging.getLogger(name="run_inference_on_triton")
|
|
|
|
|
|
class SyncGRPCTritonRunner:
    """Synchronous gRPC client that runs inference on one Triton model version.

    Construction connects to the server and verifies that it is live and ready
    and that the requested model version is ready. It then caches the model's
    input/output tensor metadata. The runner owns its ``InferenceServerClient``;
    release it with :meth:`close` or by using the runner as a context manager.
    """

    # Default timeout in seconds for each ``Run`` request (``client_timeout``).
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
        self,
        server_url: str,
        model_name: str,
        model_version: str,
        *,
        verbose: bool = False,
        resp_wait_s: Optional[float] = None,
    ):
        """Connect to Triton and load the model's input/output metadata.

        Args:
            server_url: gRPC endpoint of the server, e.g. ``"localhost:8001"``.
            model_name: Name of the model to run.
            model_version: Model version as a string, e.g. ``"1"``.
            verbose: Forwarded to ``InferenceServerClient``.
            resp_wait_s: Per-request timeout in seconds used by :meth:`Run`;
                ``None`` means ``DEFAULT_MAX_RESP_WAIT_S``.

        Raises:
            RuntimeError: If the server or model is not live/ready, or the
                server cannot be reached for the health checks.
        """
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = (
            self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
        )

        self._client = InferenceServerClient(self._server_url, verbose=self._verbose)
        try:
            self._check_ready()
            self._load_io_metadata()
        except BaseException:
            # Don't leak the gRPC channel when construction fails part-way.
            self._client.close()
            raise

    def _check_ready(self) -> None:
        """Raise RuntimeError unless the server and the model are ready."""
        try:
            error = self._verify_triton_state(self._client)
        except client_utils.InferenceServerException as exc:
            # The health-check RPCs raise (instead of returning False) when the
            # server cannot be reached; report it the same way as "not ready".
            raise RuntimeError(
                f"Could not communicate to Triton Server: {exc}"
            ) from exc
        if error:
            raise RuntimeError(f"Could not communicate to Triton Server: {error}")

        LOGGER.debug(
            "Triton server %s and model %s:%s are up and ready!",
            self._server_url,
            self._model_name,
            self._model_version,
        )

    def _load_io_metadata(self) -> None:
        """Fetch model config/metadata and cache the input/output tensor info."""
        model_config = self._client.get_model_config(
            self._model_name, self._model_version
        )
        model_metadata = self._client.get_model_metadata(
            self._model_name, self._model_version
        )
        LOGGER.info("Model config %s", model_config)
        LOGGER.info("Model metadata %s", model_metadata)

        # Tensor metadata keyed by name; dict order follows the metadata order,
        # which defines how positional inputs to Run() are matched up.
        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [InferRequestedOutput(name) for name in self._output_names]

    def Run(self, inputs) -> dict:
        """Run one synchronous inference request.

        Args:
            inputs: Sequence of numpy arrays; the i-th array is sent as the
                model input named ``self._input_names[i]``.

        Returns:
            dict mapping each model output name to its ``numpy.ndarray``.

        Raises:
            ValueError: If more arrays are given than the model has inputs.
        """
        inputs = list(inputs)
        if len(inputs) > len(self._input_names):
            raise ValueError(
                f"Model {self._model_name} takes {len(self._input_names)} "
                f"input(s) {self._input_names}, got {len(inputs)}"
            )

        infer_inputs = []
        for name, data in zip(self._input_names, inputs):
            # Declare the tensor with the datatype the model reports rather than
            # assuming UINT8; set_data_from_numpy checks the array's dtype
            # against it.
            infer_input = InferInput(name, data.shape, self._inputs[name].datatype)
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        response = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t,
        )
        return {name: response.as_numpy(name) for name in self._output_names}

    def close(self) -> None:
        """Close the underlying gRPC client; further requests are expected to fail."""
        self._client.close()

    def __enter__(self) -> "SyncGRPCTritonRunner":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returning None lets any in-flight exception propagate.
        self.close()

    def _verify_triton_state(self, triton_client):
        """Return a description of the first failed readiness check, or None.

        Checks, in order: server liveness, server readiness, and readiness of
        ``self._model_name`` at ``self._model_version``.
        """
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        if not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        if not triton_client.is_model_ready(self._model_name, self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None
|
|
|
|
|
|
if __name__ == "__main__":
    # Script-only dependency, imported here so importing this module stays cheap.
    import argparse

    parser = argparse.ArgumentParser(
        description="Send one image to a Triton-served OCR model over gRPC "
        "and print the recognized texts, scores and boxes."
    )
    parser.add_argument("--url", default="localhost:8001", help="Triton gRPC endpoint")
    parser.add_argument("--model-name", default="pp_ocr", help="model to query")
    parser.add_argument("--model-version", default="1", help="model version")
    parser.add_argument("--image", default="12.jpg", help="path of the input image")
    args = parser.parse_args()

    # cv2.imread signals failure by returning None rather than raising, so
    # check it before opening a connection to the server.
    image = cv2.imread(args.image)
    if image is None:
        raise SystemExit(f"Could not read image: {args.image}")

    runner = SyncGRPCTritonRunner(args.url, args.model_name, args.model_version)

    # Add a leading batch axis of size 1 (same as np.array([image])).
    result = runner.Run([np.expand_dims(image, axis=0)])

    # Each output is indexed by batch item first, then by detected box.
    for texts, scores, bboxes in zip(
        result["rec_texts"], result["rec_scores"], result["det_bboxes"]
    ):
        for text, score, bbox in zip(texts, scores, bboxes):
            print(
                "text=",
                text.decode("utf-8"),
                " score=",
                score,
                " bbox=",
                bbox,
            )
|