
Wrong Device for LagLlamaEstimator().create_predictor #127

Open
e-hossam96 opened this issue Feb 17, 2025 · 4 comments



e-hossam96 commented Feb 17, 2025

Hi there,

Thanks for providing Lag-Llama. It is a wonderful model.

Kindly fix the predictor's device so that it matches the estimator's device. Suggested fix:

def create_predictor(
    self,
    transformation: Transformation,
    module,
) -> PyTorchPredictor:
    prediction_splitter = self._create_instance_splitter(module, "test")
    if self.time_feat:
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES
            + ["past_time_feat", "future_time_feat"],
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            # device="cuda" if torch.cuda.is_available() else "cpu",
            device=self.device.type,
        )
    else:
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            # device="cuda" if torch.cuda.is_available() else "cpu",
            device=self.device.type,
        )
@FBielicki

Done here: #125.

It should also be self.device instead of self.device.type.
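For context, a minimal sketch in plain PyTorch of what the two values being discussed evaluate to (this is independent of the lag-llama code):

import torch

# A torch.device carries a type ("cuda"/"cpu") and an optional index.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)       # torch.device object, e.g. device(type='cuda')
print(device.type)  # plain string, e.g. 'cuda'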

@e-hossam96 (Author)

I believe there may be errors when serializing the model if we just use self.device.

predictor.serialize(path=pathlib.Path(".cache/lag-llama-plus"))

This raises an error saying that torch.device is not serializable.

P.S. This is to the best of my knowledge of the lag-llama codebase.
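A minimal sketch of the failure being described, assuming the predictor's parameters are JSON-encoded somewhere in the serialization path (an assumption; not verified against the gluonts internals): a torch.device object cannot be JSON-encoded, while the string returned by .type can.

import json
import torch

device = torch.device("cpu")

print(json.dumps(device.type))  # works: prints "cpu"

try:
    # Illustrates the assumed failure: the device object itself is rejected.
    json.dumps(device)
except TypeError as err:
    print(err)  # Object of type device is not JSON serializable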

@FBielicki

Interesting. I tested the code with both options, and in both cases it worked (without serializing).

I thought that because .type returns a string, there would be an error.

Maybe self.device.type is the better option, then.

@e-hossam96 (Author)

Thanks 😊
