@@ -809,16 +809,16 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
809
809
break ;
810
810
}
811
811
ggml_backend_t src_backend = get_buffer_backend (sched , src -> buffer );
812
- if (src_backend != NULL ) {
812
+ // if (src_backend != NULL) {
813
813
int src_prio = sched_backend_prio (sched , src_backend );
814
814
size_t src_size = ggml_nbytes (src );
815
- if (src_prio < cur_prio && src_size >= cur_size ) {
815
+ if (/* src_prio < cur_prio &&*/ src_size >= cur_size ) {
816
816
cur_prio = src_prio ;
817
817
cur_size = src_size ;
818
818
cur_backend = src_backend ;
819
819
SET_CAUSE (node , "1.src%d" , i );
820
820
}
821
- }
821
+ // }
822
822
}
823
823
return cur_backend ;
824
824
}
@@ -1025,9 +1025,21 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
1025
1025
}
1026
1026
ggml_tallocr_t src_allocr = node_allocr (src );
1027
1027
if (src_allocr != node_allocr ) {
1028
- int n_inputs = sched -> splits [cur_split ].n_inputs ++ ;
1029
- GGML_ASSERT (n_inputs < GGML_MAX_SPLIT_INPUTS );
1030
- sched -> splits [cur_split ].inputs [n_inputs ] = (struct ggml_tensor * )src ;
1028
+ // check if the input is already in the split
1029
+ bool found = false;
1030
+ for (int k = 0 ; k < sched -> splits [cur_split ].n_inputs ; k ++ ) {
1031
+ if (sched -> splits [cur_split ].inputs [k ] == src ) {
1032
+ found = true;
1033
+ break ;
1034
+ }
1035
+ }
1036
+
1037
+ if (!found ) {
1038
+ int n_inputs = sched -> splits [cur_split ].n_inputs ++ ;
1039
+ //printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, ggml_backend_name(get_allocr_backend(sched, src_allocr)));
1040
+ GGML_ASSERT (n_inputs < GGML_MAX_SPLIT_INPUTS );
1041
+ sched -> splits [cur_split ].inputs [n_inputs ] = (struct ggml_tensor * )src ;
1042
+ }
1031
1043
1032
1044
// create copies
1033
1045
size_t id = hash_id (src );
@@ -1316,6 +1328,7 @@ static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor
1316
1328
1317
1329
struct ggml_tensor * dst = node_copies [id ];
1318
1330
if (dst -> view_src != NULL ) {
1331
+ graph_init_tensor (hash_set , node_copies , node_init , src -> view_src );
1319
1332
ggml_backend_view_init (dst -> view_src -> buffer , dst );
1320
1333
}
1321
1334
else {
0 commit comments