LagLlamaEstimator().create_predictor
Hi there,
Thanks for providing Lag-Llama. It is a wonderful model.
Kindly fix the predictor device so that it matches the estimator's device:
```python
def create_predictor(
    self,
    transformation: Transformation,
    module,
) -> PyTorchPredictor:
    prediction_splitter = self._create_instance_splitter(module, "test")
    if self.time_feat:
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES
            + ["past_time_feat", "future_time_feat"],
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            # device="cuda" if torch.cuda.is_available() else "cpu",
            device=self.device.type,
        )
    else:
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            # device="cuda" if torch.cuda.is_available() else "cpu",
            device=self.device.type,
        )
```
Done here: #125
It should also be `self.device` instead of `self.device.type`.
I believe there might be some errors when serializing the model if we just use `self.device`. For example:

`predictor.serialize(path=pathlib.Path(".cache/lag-llama-plus"))`

This gives me an error that `torch.device` is not serializable.

P.S. This is to the best of my knowledge of the lag-llama code.
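The underlying distinction can be checked with a small standalone sketch (assuming only `torch` and the standard-library `json` module; GluonTS uses its own serializer, so this is only an illustration, not the exact code path): a `torch.device` object is not JSON-serializable, while its `.type` attribute is a plain string.

```python
import json

import torch

dev = torch.device("cpu")

# A torch.device object cannot go through JSON serialization directly:
try:
    json.dumps({"device": dev})
except TypeError as err:
    print("not serializable:", err)

# Its .type attribute is a plain string ("cpu" or "cuda"), so it serializes fine:
print(json.dumps({"device": dev.type}))  # prints {"device": "cpu"}
```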
Interesting. I tested the code with both options and in both cases the code worked (no serialization).
I thought that because `.type` returns a string there would be an error. Maybe `self.device.type` is the better option, then.
Thanks 😊