Skip to content

Commit 3038d5e

Browse files
authored
If input field was on device than copy it again after loading (#11)
When the field was already on device on a call to load_iteration is was not copied again.
1 parent b5d216b commit 3038d5e

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/core/serialization.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ class serialization : private boost::noncopyable {
8484
type_erased_field_view< T > field,
8585
const ser::Savepoint &savepoint,
8686
bool alsoPrevious = false) {
87+
88+
bool field_was_on_device = !field.is_on_host();
89+
8790
// Make sure data is on the Host
8891
field.update_host();
8992

@@ -122,6 +125,9 @@ class serialization : private boost::noncopyable {
122125

123126
serializer_->ReadField(
124127
name, savepoint, static_cast< void * >(field.data()), iStride, jStride, kStride, 0, alsoPrevious);
128+
129+
if (field_was_on_device)
130+
field.update_device();
125131
}
126132

127133
/** @} */

src/core/type_erased_field.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ namespace internal {
205205
* Update host pointer (calls d2h_update)
206206
*/
207207
virtual void update_host() noexcept = 0;
208+
209+
virtual bool is_on_host() const noexcept = 0;
208210
};
209211

210212
template < typename FieldType, typename T >
@@ -258,6 +260,8 @@ namespace internal {
258260

259261
virtual void update_host() noexcept override { field_helper::d2h_update(field_); }
260262

263+
virtual bool is_on_host() const noexcept override { return field_.is_on_host(); }
264+
261265
private:
262266
FieldType &field_;
263267
};
@@ -330,6 +334,8 @@ namespace internal {
330334

331335
virtual void update_host() noexcept override { field_helper::d2h_update(field_); }
332336

337+
virtual bool is_on_host() const noexcept override { return field_.is_on_host(); }
338+
333339
private:
334340
typename FieldType::storage_info_type metaData_;
335341
FieldType field_;
@@ -496,6 +502,8 @@ class type_erased_field_view {
496502
*/
497503
void update_host() const noexcept { base_->update_host(); }
498504

505+
bool is_on_host() const noexcept { return base_->is_on_host(); }
506+
499507
private:
500508
std::shared_ptr< internal::type_erased_field_interface< T > > base_;
501509
};
@@ -639,6 +647,8 @@ class type_erased_field {
639647
*/
640648
void update_host() const noexcept { base_->update_host(); }
641649

650+
bool is_on_host() const noexcept { return base_->is_on_host(); }
651+
642652
/**
643653
* @brief Convert field to view
644654
*/

0 commit comments

Comments
 (0)