
Commit 15e5404

More stream gymnastics
1 parent a5132d0 commit 15e5404

7 files changed: +81 -14 lines

exllamav2/compat.py

Lines changed: 60 additions & 4 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import torch
 import itertools
+from exllamav2.device import get_device_stream
 
 # Emulate pairwise on Python <3.10
 
@@ -63,6 +64,8 @@ def safe_move_tensor(
     # Accept torch.device, string or int
 
     device = torch.device(device)
+    from_index = tensor.device.index
+    to_index = device.index
 
     # No move
 
@@ -71,15 +74,68 @@
 
     # Copies to/from system RAM are always fine
 
-    if tensor.device.type == "cpu" or device.type == "cpu":
-        return tensor.to(device, non_blocking = non_blocking)
+    if tensor.device.type == "cpu":
+        stream = get_device_stream(to_index)
+        if stream is not None:
+            with torch.cuda.stream(stream):
+                r = tensor.to(device, non_blocking = True)
+            torch.cuda.synchronize(to_index)
+            return r
+        else:
+            return tensor.to(device, non_blocking = non_blocking)
+
+    if device.type == "cpu":
+        stream = get_device_stream(from_index)
+        if stream is not None:
+            with torch.cuda.stream(stream):
+                r = tensor.to(device, non_blocking = True)
+            torch.cuda.synchronize(from_index)
+            return r
+        else:
+            return tensor.to(device, non_blocking = non_blocking)
 
     # Source and dest are distinct CUDA devices
     # Test tensor.to (once) and if it seems to be working, let Torch decide
 
     if test_gpu_peer_copy(tensor.device, device):
-        return tensor.to(device, non_blocking = non_blocking)
+        from_stream = get_device_stream(from_index)
+        to_stream = get_device_stream(to_index)
+
+        if from_stream is not None and to_stream is not None:
+            with torch.cuda.stream(from_stream):
+                with torch.cuda.stream(to_stream):
+                    r = tensor.to(device, non_blocking = True)
+        elif from_stream is not None:
+            with torch.cuda.stream(from_stream):
+                r = tensor.to(device, non_blocking = True)
+        elif to_stream is not None:
+            with torch.cuda.stream(to_stream):
+                r = tensor.to(device, non_blocking = True)
+        else:
+            r = tensor.to(device, non_blocking = True)
+
+        if not non_blocking:
+            torch.cuda.synchronize(to_index)
+        return r
 
     # Force move tensor via CPU
 
-    return tensor.cpu().to(device)
+    from_stream = get_device_stream(from_index)
+    to_stream = get_device_stream(to_index)
+
+    if from_stream is not None:
+        with torch.cuda.stream(from_stream):
+            tensor_cpu = tensor.to("cpu", non_blocking = True)
+        torch.cuda.synchronize(from_index)
+    else:
+        tensor_cpu = tensor.cpu()
+
+    if to_stream is not None:
+        with torch.cuda.stream(to_stream):
+            r = tensor_cpu.to(device, non_blocking = True)
+        torch.cuda.synchronize(to_index)
+        return r
+    else:
+        return tensor_cpu.to(device)
+
+

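Every new branch above follows the same pattern: queue the copy on the device's registered side stream, then synchronize that device so the returned tensor is safe to consume from any stream. A minimal sketch of the upload (CPU to CUDA) case, where move_via_stream and side_stream are illustrative names rather than part of the diff:

import torch

def move_via_stream(tensor, device, side_stream):
    # side_stream stands in for whatever get_device_stream() returns
    # for the target device, or None if no stream is registered
    device = torch.device(device)
    if side_stream is None:
        return tensor.to(device)                # plain default-stream copy
    with torch.cuda.stream(side_stream):        # enqueue on the side stream
        r = tensor.to(device, non_blocking = True)
    torch.cuda.synchronize(device.index)        # copy is visible everywhere after this
    return r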
exllamav2/device.py

Lines changed: 8 additions & 1 deletion
@@ -21,9 +21,15 @@ def set_device_streams():
     global global_streams
     for(k, v) in global_streams.items():
         with torch.cuda.device(torch.device(k)):
+            torch.cuda.set_device(torch.device(k))
             torch.cuda.set_stream(v)
 
 
+def get_device_stream(index: int):
+    global global_streams
+    return global_streams.get(index)
+
+
 class ExLlamaV2DeviceContext:
 
     model: ExLlamaV2
@@ -56,7 +62,8 @@ def __init__(
         # Create streams (only one per device)
 
        if device_idx not in global_streams:
-            global_streams[device_idx] = torch.cuda.Stream(torch.device(device_idx), -100)
+            s = torch.cuda.Stream(torch.device(device_idx), -100)
+            global_streams[device_idx] = s
 
         self.stream = global_streams[device_idx]
 

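get_device_stream is a plain lookup into the registry that ExLlamaV2DeviceContext fills, one high-priority stream per device. The create-once shape, sketched with assumed names mirroring global_streams:

import torch

streams = {}  # device index -> torch.cuda.Stream

def stream_for(index: int) -> torch.cuda.Stream:
    if index not in streams:
        # priority -100 asks for the highest priority available;
        # CUDA clamps the value to the supported range
        streams[index] = torch.cuda.Stream(torch.device(index), -100)
    return streams[index]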
exllamav2/exllamav2_ext/cuda/q_matrix.cu

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ QMatrix::QMatrix
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
 
-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }
 
 QMatrix::~QMatrix()

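Passing stream as the fourth launch parameter moves shuffle_kernel off the legacy default stream and onto the stream the host code obtained from at::cuda::getCurrentCUDAStream(). From Python, extension work therefore follows whichever stream is current; a sketch, with the extension call elided since its full argument list isn't shown in this diff:

import torch

s = torch.cuda.Stream(0)                   # assumes CUDA device 0
with torch.cuda.stream(s):
    w = torch.randn(32, 32, device = 0)    # enqueued on s
    # an ext_c.make_q_matrix(...) call here would now launch
    # shuffle_kernel on s as well, ordered after the randn
torch.cuda.synchronize(0)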
exllamav2/exllamav2_ext/ext_qmatrix.cpp

Lines changed: 3 additions & 0 deletions
@@ -80,7 +80,9 @@ uintptr_t make_q_matrix
         TORCH_CHECK(temp_dq.size(0) >= dq_req, "Insufficient size of temp_dq buffer")
     }
 
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
     cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
     QMatrix* m = new QMatrix
     (
         stream,
@@ -151,6 +153,7 @@ uintptr_t make_q_matrix_split
         TORCH_CHECK(false, "Tensor split not implemented for GPTQ matrices");
     }
 
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
     cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
 
     QMatrix* m = new QMatrix

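getCurrentCUDAStream() with no argument answers for the current device, so the guard first switches to q_weight's device; without it, weights on a secondary GPU would be paired with device 0's stream. The per-device nature of "current stream" on the Python side (assumes two visible CUDA devices):

import torch

with torch.cuda.device(1):
    s = torch.cuda.current_stream()      # current stream of device 1
assert s == torch.cuda.current_stream(1)
assert s != torch.cuda.current_stream(0)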
exllamav2/exllamav2_ext/ext_stloader.cpp

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ void stloader_read
         TORCH_CHECK(load_buffer, "Can't allocate buffer for tensor");
         cuda_buffer = (uint8_t*) target.data_ptr();
         cudaSetDevice(device.value().index());
-        stream = at::cuda::getCurrentCUDAStream().stream();
+        stream = at::cuda::getCurrentCUDAStream(device.value().index()).stream();
     }
 
     // Synchronization

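Passing the device index makes the stream query explicit instead of relying on ambient current-device state around the raw cudaSetDevice() call above. The Python-side query has the same explicit form:

import torch

idx = 0                               # illustrative index
s = torch.cuda.current_stream(idx)    # stream for that specific device,
                                      # regardless of which device is current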
exllamav2/ext.py

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@
 import platform
 import threading
 from exllamav2.util import get_basic_progress
+from exllamav2.compat import safe_move_tensor
 
 extension_name = "exllamav2_ext"
 verbose = False # Print wall of text when compiling
@@ -315,8 +316,8 @@ def make_group_map_py(q_groups: torch.Tensor, num_qrows: int) -> torch.Tensor:
     return torch.tensor(group_map, dtype = torch.short, device = q_groups.device)
 
 def make_group_map(q_groups: torch.Tensor, num_qrows: int) -> torch.Tensor:
-    group_map = ext_c.make_group_map(q_groups.cpu(), num_qrows).to(q_groups.device)
-    return group_map
+    group_map = ext_c.make_group_map(q_groups.cpu(), num_qrows)
+    return safe_move_tensor(group_map, q_groups.device)
 
 
 # Create Q matrix

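make_group_map previously moved its result with a bare .to(); routing it through safe_move_tensor gives it the same stream-aware, synchronized path as every other weight move. A usage sketch (assumes a CUDA device 0 is present):

import torch
from exllamav2.compat import safe_move_tensor

t = torch.zeros(16, dtype = torch.short)    # CPU tensor, like ext_c's output
r = safe_move_tensor(t, "cuda:0")
assert r.device == torch.device("cuda", 0)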
exllamav2/linear.py

Lines changed: 5 additions & 5 deletions
@@ -560,16 +560,16 @@ def tp_split(self, broadcast_type: int, dim = None):
 
             w = {
                 "q_scale": safe_move_tensor(self.q_tensors["q_scale"][:, a // 8:b // 8], idx).contiguous(),
-                "q_scale_max": safe_move_tensor(self.q_tensors["q_scale_max"], idx).contiguous(),
-                "q_group_map": safe_move_tensor(self.q_tensors["q_group_map"], idx).contiguous(),
-                "q_groups": safe_move_tensor(self.q_tensors["q_groups"], idx).contiguous(),
+                "q_scale_max": safe_move_tensor(self.q_tensors["q_scale_max"], idx),
+                "q_group_map": safe_move_tensor(self.q_tensors["q_group_map"], idx),
+                "q_groups": safe_move_tensor(self.q_tensors["q_groups"], idx),
                 "q_weight": safe_move_tensor(self.q_tensors["q_weight"][:, a:b], idx).contiguous()
             }
 
             if "q_perm" in self.q_tensors:
                 w.update({
-                    "q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx).contiguous(),
-                    "q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx).contiguous(),
+                    "q_perm": safe_move_tensor(self.q_tensors["q_perm"], idx),
+                    "q_invperm": safe_move_tensor(self.q_tensors["q_invperm"], idx),
                 })
 
             if "bias" in self.q_tensors:

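The dropped .contiguous() calls were redundant: a whole-tensor copy already lands in dense layout, so only the column-sliced tensors (q_scale, q_weight) can arrive non-contiguous. A quick illustration:

import torch

t = torch.arange(12).reshape(3, 4)
print(t[:, 1:3].is_contiguous())    # False: a column slice is a strided view
print(t.clone().is_contiguous())    # True: a full copy keeps dense layout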