26
26
# include < unistd.h>
27
27
#endif
28
28
#include < cstring>
29
+ #include < fstream>
30
+ #include < filesystem>
31
+
32
+ namespace fs = std::filesystem;
29
33
30
34
#ifdef _WIN32
31
35
typedef SOCKET sockfd_t ;
@@ -80,6 +84,7 @@ enum rpc_cmd {
80
84
RPC_CMD_FREE_BUFFER,
81
85
RPC_CMD_BUFFER_CLEAR,
82
86
RPC_CMD_SET_TENSOR,
87
+ RPC_CMD_SET_TENSOR_HASH,
83
88
RPC_CMD_GET_TENSOR,
84
89
RPC_CMD_COPY_TENSOR,
85
90
RPC_CMD_GRAPH_COMPUTE,
@@ -89,6 +94,9 @@ enum rpc_cmd {
89
94
RPC_CMD_COUNT,
90
95
};
91
96
97
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold (10 MiB):
// the client sends only the 8-byte FNV-1a hash of the data, and falls back to sending the
// raw bytes via RPC_CMD_SET_TENSOR only if the server does not have the data cached.
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
99
+
92
100
struct rpc_msg_get_alloc_size_req {
93
101
rpc_tensor tensor;
94
102
};
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
135
143
uint8_t value;
136
144
};
137
145
146
// Response to RPC_CMD_SET_TENSOR_HASH.
struct rpc_msg_set_tensor_hash_rsp {
    // 1 = server found the hashed data in its cache and applied it to the tensor;
    // 0 = cache miss, the client must re-send the raw data via RPC_CMD_SET_TENSOR
    uint8_t result;
};
149
+
138
150
struct rpc_msg_get_tensor_req {
139
151
rpc_tensor tensor;
140
152
uint64_t offset;
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
187
199
188
200
// RPC helper functions
189
201
202
// Computes the 64-bit FNV-1a hash of a byte buffer:
// starting from the FNV offset basis, XOR in each byte and multiply by the FNV prime.
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    uint64_t hash = 0xcbf29ce484222325ULL;           // FNV-1a 64-bit offset basis
    const uint8_t * end = data + len;
    for (const uint8_t * p = data; p != end; ++p) {
        hash = (hash ^ *p) * 0x100000001b3ULL;       // FNV-1a 64-bit prime
    }
    return hash;
}
213
+
190
214
static std::shared_ptr<socket_t > make_socket (sockfd_t fd) {
191
215
#ifdef _WIN32
192
216
if (fd == INVALID_SOCKET) {
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
483
507
484
508
static void ggml_backend_rpc_buffer_set_tensor (ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
485
509
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
486
- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
510
+ rpc_tensor rpc_tensor = serialize_tensor (tensor);
511
+ if (size > HASH_THRESHOLD) {
512
+ // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
513
+ size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + sizeof (uint64_t );
514
+ std::vector<uint8_t > input (input_size, 0 );
515
+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
516
+ memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
517
+ memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
518
+ memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &hash, sizeof (hash));
519
+ rpc_msg_set_tensor_hash_rsp response;
520
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, input.data (), input.size (), &response, sizeof (response));
521
+ GGML_ASSERT (status);
522
+ if (response.result ) {
523
+ // the server has the same data, no need to send it
524
+ return ;
525
+ }
526
+ }
527
+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
487
528
size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + size;
488
529
std::vector<uint8_t > input (input_size, 0 );
489
- rpc_tensor rpc_tensor = serialize_tensor (tensor);
490
530
memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
491
531
memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
492
532
memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
@@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
772
812
773
813
class rpc_server {
774
814
public:
775
- rpc_server (ggml_backend_t backend) : backend(backend) {}
815
+ rpc_server (ggml_backend_t backend, const char * cache_dir)
816
+ : backend(backend), cache_dir(cache_dir) {
817
+ }
776
818
~rpc_server ();
777
819
778
820
void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
@@ -782,13 +824,15 @@ class rpc_server {
782
824
bool free_buffer (const rpc_msg_free_buffer_req & request);
783
825
bool buffer_clear (const rpc_msg_buffer_clear_req & request);
784
826
bool set_tensor (const std::vector<uint8_t > & input);
827
+ bool set_tensor_hash (const std::vector<uint8_t > & input, rpc_msg_set_tensor_hash_rsp & response);
785
828
bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
786
829
bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
787
830
bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
788
831
bool init_tensor (const rpc_msg_init_tensor_req & request);
789
832
bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
790
833
791
834
private:
835
+ bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
792
836
ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
793
837
ggml_tensor * create_node (uint64_t id,
794
838
struct ggml_context * ctx,
@@ -797,6 +841,7 @@ class rpc_server {
797
841
798
842
799
843
ggml_backend_t backend;
844
+ const char * cache_dir;
800
845
std::unordered_set<ggml_backend_buffer_t > buffers;
801
846
};
802
847
@@ -960,11 +1005,85 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
960
1005
}
961
1006
962
1007
const void * data = input.data () + sizeof (rpc_tensor) + sizeof (offset);
1008
+ if (cache_dir && size > HASH_THRESHOLD) {
1009
+ uint64_t hash = fnv_hash ((const uint8_t *)data, size);
1010
+ char hash_str[17 ];
1011
+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1012
+ // save to cache_dir/hash_str
1013
+ fs::path cache_file = fs::path (cache_dir) / hash_str;
1014
+ std::ofstream ofs (cache_file, std::ios::binary);
1015
+ ofs.write ((const char *)data, size);
1016
+ printf (" [%s] saved to '%s'\n " , __func__, cache_file.c_str ());
1017
+ }
963
1018
ggml_backend_tensor_set (tensor, data, offset, size);
964
1019
ggml_free (ctx);
965
1020
return true ;
966
1021
}
967
1022
1023
+ bool rpc_server::get_cached_file (uint64_t hash, std::vector<uint8_t > & data) {
1024
+ if (!cache_dir) {
1025
+ return false ;
1026
+ }
1027
+ char hash_str[17 ];
1028
+ snprintf (hash_str, sizeof (hash_str), " %016" PRIx64, hash);
1029
+ fs::path cache_file = fs::path (cache_dir) / hash_str;
1030
+ if (!fs::exists (cache_file)) {
1031
+ return false ;
1032
+ }
1033
+ std::ifstream ifs (cache_file, std::ios::binary);
1034
+ ifs.seekg (0 , std::ios::end);
1035
+ size_t size = ifs.tellg ();
1036
+ ifs.seekg (0 , std::ios::beg);
1037
+ data.resize (size);
1038
+ ifs.read ((char *)data.data (), size);
1039
+ return true ;
1040
+ }
1041
+
1042
// Handles RPC_CMD_SET_TENSOR_HASH: instead of raw tensor data, the client sent
// the FNV-1a hash of it. If a file with that hash exists in the server cache,
// copy its contents into the target tensor and report a hit; otherwise report
// a miss so the client falls back to RPC_CMD_SET_TENSOR with the full payload.
// Returns false only on a malformed request or a tensor that cannot be
// deserialized (the caller then drops the connection).
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
{
    // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
    if (input.size() != sizeof(rpc_tensor) + 16) {
        return false;
    }
    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
    uint64_t offset;
    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
    const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
    std::vector<uint8_t> cached_file;
    if (!get_cached_file(*hash, cached_file)) {
        // cache miss - not an error; tell the client to re-send the raw data
        response.result = 0;
        return true;
    }
    size_t size = cached_file.size();
    // no_alloc context: only tensor metadata is created here, the destination
    // memory already lives in the backend buffer
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
    if (tensor == nullptr) {
        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
        ggml_free(ctx);
        return false;
    }
    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void *)tensor->buffer, tensor->data, offset, size, *hash);

    // sanitize tensor->data: the client-supplied pointer plus offset must lie
    // inside the backend buffer, and the cached blob must fit in the remainder,
    // before any bytes are written
    {
        const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
        const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
        }
    }
    ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
    response.result = 1;
    ggml_free(ctx);
    return true;
}
1086
+
968
1087
bool rpc_server::init_tensor (const rpc_msg_init_tensor_req & request) {
969
1088
struct ggml_init_params params {
970
1089
/* .mem_size =*/ ggml_tensor_overhead(),
@@ -1148,8 +1267,9 @@ rpc_server::~rpc_server() {
1148
1267
}
1149
1268
}
1150
1269
1151
- static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1152
- rpc_server server (backend);
1270
+ static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
1271
+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1272
+ rpc_server server (backend, cache_dir);
1153
1273
while (true ) {
1154
1274
uint8_t cmd;
1155
1275
if (!recv_data (sockfd, &cmd, 1 )) {
@@ -1260,6 +1380,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1260
1380
}
1261
1381
break ;
1262
1382
}
1383
+ case RPC_CMD_SET_TENSOR_HASH: {
1384
+ std::vector<uint8_t > input;
1385
+ if (!recv_msg (sockfd, input)) {
1386
+ return ;
1387
+ }
1388
+ rpc_msg_set_tensor_hash_rsp response;
1389
+ if (!server.set_tensor_hash (input, response)) {
1390
+ return ;
1391
+ }
1392
+ if (!send_msg (sockfd, &response, sizeof (response))) {
1393
+ return ;
1394
+ }
1395
+ break ;
1396
+ }
1263
1397
case RPC_CMD_INIT_TENSOR: {
1264
1398
rpc_msg_init_tensor_req request;
1265
1399
if (!recv_msg (sockfd, &request,sizeof (request))) {
@@ -1335,7 +1469,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1335
1469
}
1336
1470
}
1337
1471
1338
- void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1472
+ void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint,
1473
+ const char * cache_dir,
1474
+ size_t free_mem, size_t total_mem) {
1339
1475
std::string host;
1340
1476
int port;
1341
1477
if (!parse_endpoint (endpoint, host, port)) {
@@ -1364,7 +1500,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
1364
1500
}
1365
1501
printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
1366
1502
fflush (stdout);
1367
- rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
1503
+ rpc_serve_client (backend, cache_dir, client_socket->fd , free_mem, total_mem);
1368
1504
printf (" Client connection closed\n " );
1369
1505
fflush (stdout);
1370
1506
}
0 commit comments