Commit c0bd08a

remove meta Tensor warning when dispatch meta operator (#2246)

1 parent d684196 commit c0bd08a
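
The meta backend previously inferred output shapes by calling input.numpy() on meta tensors, which materialized placeholder data and emitted a meta Tensor warning on every dispatch. The meta operators now derive shapes from zero-filled numpy placeholders (and gain setitem, meshgrid, and permute). The commit also bundles related fixes: a sliding-window cache update patched into transformers, reordered mask composition in masking_utils, and dtype/scalar handling in the cpu and npu backends.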

File tree

16 files changed: +107 −29 lines changed

examples/transformers/inference/deepseek-ocr/run_dpsk_ocr.py (1 addition, 1 deletion)

@@ -2,7 +2,7 @@
 import mindnlp
 from transformers import AutoModel, AutoTokenizer

-model_name = 'lvyufeng/DeepSeek-OCR-Community-Latest'
+model_name = 'lvyufeng/DeepSeek-OCR'


 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

mindnlp/transformers/__init__.py (2 additions, 1 deletion)

@@ -10,7 +10,7 @@
 from .masking_utils import create_causal_mask, create_sliding_window_causal_mask, create_masks_for_generate
 from .modeling_utils import construct_pipeline_parallel_model, _load_pretrained_model_wrapper, \
     _get_resolved_checkpoint_files_wrapper
-from .cache_utils import dynamic_layer_update
+from .cache_utils import dynamic_layer_update, dynamic_sliding_window_layer_update
 from .tokenization_utils import apply_chat_template_wrapper
 from .trainer import training_step
 from ..utils.decorators import dtype_wrapper, patch_dtype_wrapper, patch_wrappers
@@ -70,5 +70,6 @@ def empty_fn(*args, **kwargs):
 transformers.trainer.Trainer.training_step = training_step

 transformers.cache_utils.DynamicLayer.update = dynamic_layer_update
+transformers.cache_utils.DynamicSlidingWindowLayer.update = dynamic_sliding_window_layer_update
 # add mindnlp.transformers modules/attrs to lazymodule
 # setattr(sys.modules[__name__], 'test_ms_model', test_ms_model)
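
The patch reroutes transformers' cache layers by rebinding their update methods at class level. A minimal, self-contained sketch of that pattern (Layer and patched_update are hypothetical stand-ins, not mindnlp or transformers APIs):

# Minimal sketch of class-level method rebinding, the pattern used above to
# install dynamic_sliding_window_layer_update.
class Layer:
    def update(self, x):
        return x + 1

def patched_update(self, x):
    # A plain function assigned to the class becomes a bound method on access.
    return x * 2

Layer.update = patched_update
assert Layer().update(3) == 6  # instances now dispatch to the replacement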

mindnlp/transformers/cache_utils.py (36 additions, 0 deletions)

@@ -27,3 +27,39 @@ def dynamic_layer_update(
     self.keys = mindtorch.cat([self.keys, key_states], dim=-2)
     self.values = mindtorch.cat([self.values, value_states], dim=-2)
     return self.keys, self.values
+
+
+def dynamic_sliding_window_layer_update(
+    self,
+    key_states: mindtorch.Tensor,
+    value_states: mindtorch.Tensor,
+    cache_kwargs: Optional[dict[str, Any]] = None,
+) -> tuple[mindtorch.Tensor, mindtorch.Tensor]:
+    """
+    Update the key and value caches in-place, and return the necessary keys and value states.
+
+    Args:
+        key_states (`torch.Tensor`): The new key states to cache.
+        value_states (`torch.Tensor`): The new value states to cache.
+        cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+    Returns:
+        tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+    """
+    # Lazy initialization
+    if not self.is_initialized:
+        self.lazy_initialization(key_states)
+        full_key_states = key_states
+        full_value_states = value_states
+    else:
+        # Compute the full states
+        full_key_states = mindtorch.cat([self.keys, key_states], dim=-2)
+        full_value_states = mindtorch.cat([self.values, value_states], dim=-2)
+
+    self.cumulative_length += key_states.shape[-2]
+
+    # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that)
+    self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
+    self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
+
+    # Return the full states
+    return full_key_states, full_value_states
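
The update keeps at most sliding_window - 1 past tokens in the cache while still returning the full concatenated states for the current attention step. A plain-numpy sketch of that windowing arithmetic (illustrative only, not the mindtorch implementation):

import numpy as np

# The cache retains at most `sliding_window - 1` past tokens, while the full
# concatenated states are what update() returns for the current step.
sliding_window = 4
cache = np.empty((1, 1, 0, 8))            # [batch, heads, seq_len, head_dim]
for step in range(6):
    new = np.random.rand(1, 1, 1, 8)      # one new token per step
    full = np.concatenate([cache, new], axis=-2)   # what update() returns
    cache = full[:, :, -sliding_window + 1:, :]    # what update() keeps
    assert cache.shape[-2] <= sliding_window - 1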

mindnlp/transformers/masking_utils.py (7 additions, 6 deletions)

@@ -463,6 +463,7 @@ def eager_mask(
     """
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
+
     mask = sdpa_mask(
         batch_size=batch_size,
         cache_position=cache_position,
@@ -785,12 +786,6 @@ def create_sliding_window_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

-    # If we detected packing format
-    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
-        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
-        allow_is_causal_skip = False
-
-    # Allow slight deviations from sliding causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
@@ -802,6 +797,12 @@
         mask_factory_function = and_masks(mask_factory_function, and_mask_function)
         allow_is_causal_skip = False

+
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # We now create the mask
     causal_mask = mask_interface(
         batch_size=batch_size,
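
Both the moved packing block and the surrounding code extend mask_factory_function by and-composing predicate callables. A minimal sketch of that composition semantics (this and_masks helper is illustrative, not the transformers implementation):

# And-composition of mask predicates: the combined mask allows a position
# only if every component predicate allows it.
def and_masks(*fns):
    def combined(batch, head, q_idx, kv_idx):
        return all(fn(batch, head, q_idx, kv_idx) for fn in fns)
    return combined

causal = lambda b, h, q, kv: kv <= q          # attend only to the past
window = lambda b, h, q, kv: q - kv < 3       # hypothetical window of 3
mask_fn = and_masks(causal, window)
assert mask_fn(0, 0, 5, 4) and not mask_fn(0, 0, 5, 1)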

mindnlp/utils/safetensors_patch.py (0 additions, 1 deletion)

@@ -98,7 +98,6 @@ def get(self, slice=None):
         if not SUPPORT_BF16 and self.info["dtype"] == "BF16":
             array = array.astype(np.float16)
         tensor = mindtorch.from_numpy(array)
-        tensor._ptr = array.ctypes.data
         return tensor

     @property

mindtorch/_apis/cpu.py (3 additions, 1 deletion)

@@ -156,6 +156,8 @@ def pad_v3(input, new_pad, mode, value=None, contiguous=True):
 def cumsum(self, dim, dtype):
     if self.shape[dim] == 0:
         return mindspore.tensor([], dtype=self.dtype)
+    if self.dtype == mindspore.int64:
+        return cast(legacy.cum_sum(cast(self, mindspore.int32), dim, False, False), mindspore.int64)
     return legacy.cum_sum(self, dim, False, False)

 def reduce_any(input, axis, keepdims):
@@ -1228,7 +1230,7 @@ def search_sorted(sorted_sequence, values, sorter, dtype, right):
     return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)

 def scatter_nd_update(input, indices, updates):
-    return legacy.scatter_nd_update(input, indices, updates, True)
+    return legacy.scatter_nd_update(input, indices, cast(updates, input.dtype), True)

 def triu_indices(row, col, offset, dtype):
     return legacy.triu_indices(row, col, offset, dtype)
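
The new int64 branch works around a cumsum kernel limitation by narrowing to int32, computing, and casting back. A numpy sketch of the same cast-compute-cast round trip (illustrative; values outside the int32 range would overflow inside the narrowed step):

import numpy as np

# Cast-compute-cast round trip like the int64 cumsum branch above.
x = np.arange(5, dtype=np.int64)
out = np.cumsum(x.astype(np.int32)).astype(np.int64)
assert out.dtype == np.int64 and out[-1] == 10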

mindtorch/_apis/meta.py (23 additions, 4 deletions)

@@ -62,21 +62,21 @@ def inplace_normal(input, *args):
 __all__.append('inplace_normal')

 def getitem(input, slice):
-    out = input.numpy()[slice]
+    out = np.zeros(input.shape)[slice]
     out = Tensor_(init='none', shape=out.shape, dtype=input.dtype)
     return mindtorch.Tensor(out)

 __all__.append('getitem')

-def sub(input, other, alpha):
+def sub(input, other, alpha=1.0):
     if isinstance(input, mindtorch.Tensor):
         return input
     return other

 __all__.append('sub')

 def pad_v3(input, pad, mode, value):
-    out = np.pad(input.numpy(), pad, mode, constant_values=value)
+    out = np.pad(np.zeros(input.shape), pad, mode, constant_values=value)
     out = Tensor_(init='none', shape=out.shape, dtype=input.dtype)
     return mindtorch.Tensor(out)

@@ -94,7 +94,7 @@ def cast(input, dtype):
 __all__.append('cast')

 def index_select(input, dim, index):
-    out = np.take(input.numpy(), index.numpy(), dim)
+    out = np.take(np.zeros(input.shape), np.zeros(index.shape, dtype=np.int64), dim)
     out = Tensor_(init='none', shape=out.shape, dtype=input.dtype)
     return mindtorch.Tensor(out)

@@ -146,6 +146,9 @@ def tril(input, k):
 __all__.append('tril')

 def reshape(input, shape):
+    if -1 in shape:
+        out = np.zeros(input.shape).reshape(shape)
+        shape = out.shape
     out = Tensor_(init='none', shape=tuple(shape), dtype=input.dtype)
     return mindtorch.Tensor(out)

@@ -414,4 +417,20 @@ def pad(input, pad, mode='constant', value=None):
         raise ValueError('pad size must be 2, 4 or 6')

     out = Tensor_(init='none', shape=new_size, dtype=input.dtype)
+    return mindtorch.Tensor(out)
+
+def setitem(self, slice, value):
+    return self
+
+def meshgrid(args, lambd):
+    res = np.meshgrid(*args, indexing=lambd)
+    outs = ()
+    for r in res:
+        out = Tensor_(init='none', shape=r.shape, dtype=args[0].dtype)
+        out = mindtorch.Tensor(out)
+        outs += (out,)
+    return outs
+
+def permute(input, dims):
+    out = Tensor_(init='none', shape=dims, dtype=input.dtype)
     return mindtorch.Tensor(out)
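
The recurring pattern in this file: a meta operator only needs the output shape, so it runs the numpy op on a zero-filled placeholder of the input's shape instead of calling input.numpy(), which is what triggered the meta Tensor warning. A standalone sketch of the idea (meta_getitem_shape is a hypothetical helper for illustration):

import numpy as np

# Shape-only inference: run the op on a zero-filled placeholder of the input
# shape and keep just the resulting shape; real tensor data is never read.
def meta_getitem_shape(input_shape, idx):
    return np.zeros(input_shape)[idx].shape

assert meta_getitem_shape((4, 6), (slice(1, 3), 2)) == (2,)
assert meta_getitem_shape((4, 6), slice(None, 2)) == (2, 6)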

mindtorch/_apis/npu_310b.py (3 additions, 0 deletions)

@@ -1553,6 +1553,9 @@ def unique_dim(input, sorted, return_inverse, dim):
     return legacy.unique_dim(input, sorted, return_inverse, dim)

 def inplace_add(input, other, alpha):
+    if isinstance(other, numbers.Number):
+        other = mindspore.Tensor(other, dtype=input.dtype)
+
     if ENABLE_PYBOOST:
         return pyboost.inplace_add_ext_op(input, other, alpha)
     return legacy.inplace_add(input, other)
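
inplace_add now promotes Python scalars to tensors of the destination dtype before hitting the in-place kernel. A numpy sketch of that guard (numpy stands in for mindspore here; not the backend implementation):

import numbers
import numpy as np

# Wrap Python numbers as arrays of the destination dtype before the in-place
# op, so the buffer's dtype is preserved.
def inplace_add(buf, other, alpha=1):
    if isinstance(other, numbers.Number):
        other = np.asarray(other, dtype=buf.dtype)
    buf += alpha * other
    return buf

x = np.zeros(3, dtype=np.float16)
assert inplace_add(x, 1).dtype == np.float16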

mindtorch/_apis/npu_910a.py (15 additions, 1 deletion)

@@ -1124,7 +1124,7 @@ def sqrt(input):
     return legacy.sqrt(input)

 def masked_scatter(input, mask, value):
-    return legacy.masked_scatter(input, mask, value)
+    return legacy.masked_scatter(input, mask, cast(value, input.dtype))

 def neg(input):
     if ENABLE_PYBOOST:
@@ -1532,6 +1532,8 @@ def unique_dim(input, sorted, return_inverse, dim):
     return legacy.unique_dim(input, sorted, return_inverse, dim)

 def inplace_add(input, other, alpha):
+    if isinstance(other, numbers.Number):
+        other = mindspore.Tensor(other, dtype=input.dtype)
     if ENABLE_PYBOOST:
         return pyboost.inplace_add_ext_op(input, other, alpha)
     return legacy.inplace_add(input, other)
@@ -1788,6 +1790,11 @@ def log2(input):
     return legacy.log2(input)

 def bucketize(input, boundaries, right=False):
+    if isinstance(boundaries, mindtorch.Tensor):
+        boundaries = boundaries.tolist()
+
+    if not boundaries:
+        return zeros_like(input)
     epsilon_ = 0. if right else 1.e-6
     boundaries = [boundary + epsilon_ for boundary in boundaries]
     return legacy.bucketize(input, boundaries)
@@ -2095,13 +2102,20 @@ def _process_dim_in_multi_dim_index(prev_result, orig_tensor, index, dim, indexe
             result = _do_select(prev_result, dim, index.item(), dim_index, prev_shape)
             del prev_shape[dim]
             return result, dim, remain_indexes, prev_shape
+
         # process index with Tensor bool type
         result = expand_dims(prev_result, dim)
         index_for_bool = tensor_1d if index else empty_tensor_1d
         _record_tensor_index(index_for_bool, remain_indexes, dim)
         prev_shape.insert(dim, 1)
         dim += 1
         return result, dim, remain_indexes, prev_shape
+
+    if index.dtype == mindtorch.bool and prev_result.ndim == 1:
+        result = masked_select(prev_result, index)
+        dim += 1
+        return result, dim, remain_indexes, prev_shape
+
     _record_tensor_index(index, remain_indexes, dim)
     dim += 1
     return result, dim, remain_indexes, prev_shape
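
bucketize now tolerates tensor boundaries and empty boundary lists; the pre-existing epsilon shift keeps values equal to a boundary in the left bucket when right=False. A numpy sketch of the equivalent behavior (illustrative, using np.searchsorted in place of legacy.bucketize):

import numpy as np

# Tensor boundaries become a plain list, an empty list maps everything to
# bucket 0, and the epsilon emulates the right=False convention.
def bucketize(values, boundaries, right=False):
    boundaries = list(boundaries)
    if not boundaries:
        return np.zeros_like(values, dtype=np.int64)
    eps = 0.0 if right else 1e-6
    return np.searchsorted(np.asarray(boundaries) + eps, values, side='left')

assert list(bucketize(np.array([0.5, 1.0, 2.5]), [1.0, 2.0])) == [0, 0, 2]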
