otf_torch_message_handler.py
"""
OTF Codec for functionality requiring importing torch
"""
import io
import json
import logging
import os
import struct
from builtins import bytearray, bytes
import torch
from ts.protocol.otf_message_handler import encode_response_headers
from ts.utils.util import deprecated
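
# Frame layout written by create_predict_response, reconstructed here from the
# struct.pack calls below. Every integer is a big-endian int32; lengths prefix
# the bytes that follow them:
#
#   int32 code
#   int32 len | message (utf-8)
#   per request in req_id_map:
#       int32 len | req_id (utf-8)
#       int32 len | content_type       (0 when absent)
#       int32 http_code
#       int32 len | reason_phrase      (0 when absent)
#       response headers               (encode_response_headers; a single 0
#                                       when there is no context)
#       int32 len | payload            (bytes, utf-8 string, torch.save'd
#                                       tensor, or JSON)
#   int32 -1   end-of-batch marker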
def create_predict_response(
    ret, req_id_map, message, code, context=None, ts_stream_next=False
):
    """
    Create inference response.

    :param ret: model output, indexed like req_id_map (or None on error)
    :param req_id_map: map of batch index to request id
    :param message: status message
    :param code: status code
    :param context: optional request context carrying per-request
        content type, status, and response headers
    :param ts_stream_next: when True, mark the response as an intermediate
        chunk of a streaming prediction
    :return: encoded response frame as a bytearray, or None on non-zero ranks
    """
    # In distributed (e.g. torchrun) setups, only the rank-0 worker replies.
    if str(os.getenv("LOCAL_RANK", 0)) != "0":
        return None

    msg = bytearray()
    msg += struct.pack("!i", code)

    buf = message.encode("utf-8")
    msg += struct.pack("!i", len(buf))
    msg += buf

    for idx in req_id_map:
        req_id = req_id_map.get(idx).encode("utf-8")
        msg += struct.pack("!i", len(req_id))
        msg += req_id
        if context is None:
            # Encoding Content-Type
            msg += struct.pack("!i", 0)  # content_type

            # Encoding the per prediction HTTP response code
            # status code and reason phrase set to none
            msg += struct.pack("!i", code)
            msg += struct.pack("!i", 0)  # No code phrase is returned
            # Response headers none
            msg += struct.pack("!i", 0)
        else:
            if ts_stream_next is True:
                context.set_response_header(idx, "ts_stream_next", "true")
            elif context.stopping_criteria:
                is_stop = context.stopping_criteria[idx](ret[idx])
                if is_stop is not None:
                    ts_stream_next = "false" if is_stop else "true"
                    context.set_response_header(idx, "ts_stream_next", ts_stream_next)
            elif "true" == context.get_response_headers(idx).get("ts_stream_next"):
                context.set_response_header(idx, "ts_stream_next", "false")

            content_type = context.get_response_content_type(idx)
            if content_type is None or len(content_type) == 0:
                msg += struct.pack("!i", 0)  # content_type
            else:
                msg += struct.pack("!i", len(content_type))
                msg += content_type.encode("utf-8")

            sc, phrase = context.get_response_status(idx)
            http_code = sc if sc is not None else 200
            http_phrase = phrase if phrase is not None else ""
            msg += struct.pack("!i", http_code)
            msg += struct.pack("!i", len(http_phrase))
            msg += http_phrase.encode("utf-8")
            # Response headers
            msg += encode_response_headers(context.get_response_headers(idx))
        if ret is None:
            buf = b"error"
            msg += struct.pack("!i", len(buf))
            msg += buf
        else:
            val = ret[idx]
            # NOTE: Process bytes/bytearray case before processing the string case.
            if isinstance(val, (bytes, bytearray)):
                msg += struct.pack("!i", len(val))
                msg += val
            elif isinstance(val, str):
                buf = val.encode("utf-8")
                msg += struct.pack("!i", len(buf))
                msg += buf
            elif isinstance(val, torch.Tensor):
                buff = io.BytesIO()
                torch.save(val, buff)
                buff.seek(0)
                val_bytes = buff.read()
                msg += struct.pack("!i", len(val_bytes))
                msg += val_bytes
            else:
                try:
                    json_value = json.dumps(val, indent=2).encode("utf-8")
                    msg += struct.pack("!i", len(json_value))
                    msg += json_value
                except TypeError:
                    logging.warning("Unable to serialize model output.", exc_info=True)
                    return create_predict_response(
                        None, req_id_map, "Unsupported model output data type.", 503
                    )

    msg += struct.pack("!i", -1)  # End of list
    return msg

@deprecated(
    version=1.0,
    replacement="ts.handler_utils.utils.send_intermediate_predict_response",
)
def send_intermediate_predict_response(ret, req_id_map, message, code, context=None):
    if str(os.getenv("LOCAL_RANK", 0)) != "0":
        return None
    msg = create_predict_response(ret, req_id_map, message, code, context, True)
    context.cl_socket.sendall(msg)
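
if __name__ == "__main__":
    # Minimal round-trip sketch (illustrative only, not part of the module's
    # public surface): build a single-request frame without a context object,
    # then walk the length-prefixed layout documented at the top of the file.
    # The sample values ("req-0", "hello", "OK", 200) are made up for this demo.
    frame = create_predict_response({0: "hello"}, {0: "req-0"}, "OK", 200)
    buf = io.BytesIO(frame)

    def read_int():
        # All fields are big-endian int32 values.
        return struct.unpack("!i", buf.read(4))[0]

    status = read_int()                # status code (200)
    status_msg = buf.read(read_int())  # status message (b"OK")
    req_id = buf.read(read_int())      # request id (b"req-0")
    read_int()                         # content-type length (0: absent)
    http_code = read_int()             # per-prediction HTTP code
    read_int()                         # reason-phrase length (0: absent)
    read_int()                         # response-headers length (0: none)
    payload = buf.read(read_int())     # prediction payload (b"hello")
    assert read_int() == -1            # end-of-batch marker
    print(status, status_msg, req_id, http_code, payload)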