TypeError: unsupported operand type(s) for /: 'Tensor' and 'NoneType' when use mc-ebc and mean pooling #2828


Open

tiankongdeguiji (Contributor) opened this issue Mar 17, 2025 · 1 comment

tiankongdeguiji commented Mar 17, 2025

When using mc-ebc (ManagedCollisionEmbeddingBagCollection) with pooling=PoolingType.MEAN, we encounter TypeError: unsupported operand type(s) for /: 'Tensor' and 'NoneType'. This can be reproduced with the command torchrun --master_addr=localhost --master_port=49941 --nnodes=1 --nproc-per-node=2 test_mc_ebc_mean.py, using torchrec==1.1.0+cu124, torch==2.6.0+cu124, and fbgemm-gpu==1.1.0+cu124.

test_mc_ebc_mean.py

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    LFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.inference.state_dict_transform import state_dict_gather, state_dict_to_device
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist
from torchrec.distributed.sharding_plan import get_default_sharders


rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()


large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.MEAN,
    )
    for i in range(large_table_cnt)
]
large_mc_modules = {
    t.name: MCHManagedCollisionModule(
        zch_size=4096,
        device=device, #'meta',
        eviction_interval=10,
        eviction_policy=LFU_EvictionPolicy()
    ) 
    for t in large_tables
}
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=64,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.MEAN,
    )
    for i in range(small_table_cnt)
]
small_mc_modules = {
    t.name: MCHManagedCollisionModule(
        zch_size=64,
        device=device, #'meta',
        eviction_interval=10,
        eviction_policy=LFU_EvictionPolicy()
    ) 
    for t in small_tables
}


def gen_constraints(
    sharding_type: ShardingType = ShardingType.ROW_WISE,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = ManagedCollisionEmbeddingBagCollection(
            EmbeddingBagCollection(tables=large_tables + small_tables, device="meta"),
            ManagedCollisionCollection(
                dict(large_mc_modules, **small_mc_modules),
                large_tables + small_tables,
            )
        )
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb, _ = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.ROW_WISE)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
# sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, get_default_sharders(), dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])

pipeline = TrainPipelineSparseDist(
    sharded_model, optimizer, sharded_model.device, execute_all_batches=True
)


for _ in range(20):
    batch_size = 64
    lengths_large = torch.randint(0, 10, (batch_size * large_table_cnt,))
    lengths_small = torch.randint(0, 10, (batch_size * small_table_cnt,))
    kjt = KeyedJaggedTensor(
        keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
        + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
        values=torch.cat([
            torch.randint(0, 4096, (torch.sum(lengths_large),)),
            torch.randint(0, 64, (torch.sum(lengths_small),)),
        ]),
        lengths=torch.cat([lengths_large, lengths_small]),
    ).to(device=device)
    losses = sharded_model.forward(kjt)
    torch.sum(losses, dim=0).backward()
    optimizer.step()

Error info:

[rank0]: Traceback (most recent call last):
[rank0]:   File "test_mc_ebc_mean.py", line 164, in <module>
[rank0]:     losses = sharded_model.forward(kjt)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]:     return self._dmp_wrapped_module(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "test_mc_ebc_mean.py", line 120, in forward
[rank0]:     return torch.mean(self.linear(emb.values()))
[rank0]:                                   ^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 451, in __getattr__
[rank0]:     res = LazyAwaitable._wait_async(self)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 409, in _wait_async
[rank0]:     obj._result = obj.wait()
[rank0]:                   ^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 336, in wait
[rank0]:     ret = callback(ret)
[rank0]:           ^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 1732, in _apply_mean_pooling
[rank0]:     keyed_tensor.values() / divisor
[rank0]:     ~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~
[rank0]: TypeError: unsupported operand type(s) for /: 'Tensor' and 'NoneType'
tiankongdeguiji (Contributor, Author) commented:

Hi @sarckk @iamzainhuda @henrylhtsang @PaulZhang12, could you take a look?

I think _create_mean_pooling_divisor should be called in BaseShardedManagedCollisionEmbeddingCollection. We fix this in #2829.
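
For context, here is a minimal standalone sketch of mean pooling over a jagged batch (an illustration only, not torchrec's internal _apply_mean_pooling): the per-bag embedding sums are divided by a divisor derived from the feature lengths, so if that divisor is never created upstream, the division hits None and raises exactly this TypeError. The mean_pool helper and divisor variable below are hypothetical names used for illustration.

import torch

def mean_pool(values, lengths, divisor=None):
    # Sum the embeddings of each bag (bag boundaries come from `lengths`).
    offsets = [0] + lengths.cumsum(0).tolist()
    sums = torch.stack(
        [values[s:e].sum(dim=0) for s, e in zip(offsets[:-1], offsets[1:])]
    )
    # If no divisor was created upstream, this reproduces the failure mode:
    # TypeError: unsupported operand type(s) for /: 'Tensor' and 'NoneType'
    return sums / divisor

values = torch.randn(5, 4)                     # 5 ids, embedding_dim=4
lengths = torch.tensor([2, 3])                 # two bags of sizes 2 and 3
divisor = lengths.clamp(min=1).unsqueeze(1).float()
print(mean_pool(values, lengths, divisor))     # mean pooling works
# print(mean_pool(values, lengths))            # raises the TypeError above

This is why the MEAN pooling path needs the divisor to be set up before the pooled lookup, which is what the proposed fix addresses.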
