|
4 | 4 | import logging
|
5 | 5 | import uuid
|
6 | 6 | import warnings
|
| 7 | +from copy import deepcopy |
7 | 8 |
|
8 | 9 | import numpy as np
|
9 | 10 | import pandas as pd
|
@@ -74,25 +75,44 @@ def _get_context_metadata(self):
|
74 | 75 | if self._sequence_key:
|
75 | 76 | context_columns += self._sequence_key
|
76 | 77 |
|
77 |
| - for column in context_columns: |
78 |
| - context_columns_dict[column] = self.metadata.columns[column] |
79 |
| - # Context datetime SDTypes for PAR have already been converted to float timestamp |
80 |
| - if context_columns_dict[column]['sdtype'] == 'datetime': |
81 |
| - context_columns_dict[column] = {'sdtype': 'numerical'} |
82 |
| - |
83 | 78 | for column, column_metadata in self._extra_context_columns.items():
|
84 | 79 | context_columns_dict[column] = column_metadata
|
85 | 80 |
|
| 81 | + for column in context_columns: |
| 82 | + context_columns_dict[column] = self.metadata.columns[column] |
| 83 | + |
| 84 | + context_columns_dict = self._update_context_column_dict(context_columns_dict) |
86 | 85 | context_metadata_dict = {'columns': context_columns_dict}
|
87 | 86 | return SingleTableMetadata.load_from_dict(context_metadata_dict)
|
88 | 87 |
|
89 |
| - def _get_context_datetime_columns(self): |
90 |
| - datetime_columns = [] |
| 88 | + def _update_context_column_dict(self, context_columns_dict): |
| 89 | + """Update context column dictionary based on available transformers. |
| 90 | +
|
| 91 | + Args: |
| 92 | + context_columns_dict (dict): |
| 93 | + Dictionary of context columns. |
| 94 | +
|
| 95 | + Returns: |
| 96 | + dict: |
| 97 | + Updated context column metadata. |
| 98 | + """ |
| 99 | + default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype) |
91 | 100 | for column in self.context_columns:
|
92 |
| - if self.metadata.columns[column]['sdtype'] == 'datetime': |
93 |
| - datetime_columns.append(column) |
| 101 | + column_metadata = self.metadata.columns[column] |
| 102 | + sdtype = column_metadata['sdtype'] |
| 103 | + if default_transformers_by_sdtype.get(column_metadata['sdtype']): |
| 104 | + context_columns_dict[column] = {'sdtype': 'numerical'} |
94 | 105 |
|
95 |
| - return datetime_columns |
| 106 | + return context_columns_dict |
| 107 | + |
| 108 | + def _get_context_columns_for_processing(self): |
| 109 | + columns_to_be_processed = [] |
| 110 | + default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype) |
| 111 | + for column in self.context_columns: |
| 112 | + if default_transformers_by_sdtype.get(self.metadata.columns[column]['sdtype']): |
| 113 | + columns_to_be_processed.append(column) |
| 114 | + |
| 115 | + return columns_to_be_processed |
96 | 116 |
|
97 | 117 | def __init__(
|
98 | 118 | self,
|
@@ -545,28 +565,30 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
|
545 | 565 | set(context_columns.columns), set(self._context_synthesizer._model.columns)
|
546 | 566 | )
|
547 | 567 | )
|
548 |
| - context_columns = self._process_datetime_columns_in_context_columns(context_columns) |
| 568 | + context_columns = self._process_context_columns(context_columns) |
549 | 569 |
|
550 | 570 | condition_columns = context_columns[condition_columns].to_dict('records')
|
551 | 571 | synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
|
552 | 572 | context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
|
553 | 573 | context.update(context_columns)
|
554 | 574 | return self._sample(context, sequence_length)
|
555 | 575 |
|
556 |
| - def _process_datetime_columns_in_context_columns(self, context_columns): |
557 |
| - """Process datetime columns by transforming them using the data processor. |
| 576 | + def _process_context_columns(self, context_columns): |
| 577 | + """Process context columns by applying appropriate transformations. |
558 | 578 |
|
559 | 579 | Args:
|
560 | 580 | context_columns (pandas.DataFrame):
|
561 |
| - Context values containing potential datetime columns. |
| 581 | + Context values containing potential columns for transformation. |
562 | 582 |
|
563 | 583 | Returns:
|
564 | 584 | context_columns (pandas.DataFrame):
|
565 |
| - Updated context columns with transformed datetime values. |
| 585 | + Updated context columns with transformed values. |
566 | 586 | """
|
567 |
| - datetime_columns = self._get_context_datetime_columns() |
568 |
| - if datetime_columns: |
569 |
| - transformed = self._data_processor.transform(context_columns[datetime_columns]) |
570 |
| - context_columns[datetime_columns] = transformed[datetime_columns] |
| 587 | + columns_to_be_processed = self._get_context_columns_for_processing() |
| 588 | + |
| 589 | + if columns_to_be_processed: |
| 590 | + context_columns[columns_to_be_processed] = self._data_processor.transform( |
| 591 | + context_columns[columns_to_be_processed] |
| 592 | + ) |
571 | 593 |
|
572 | 594 | return context_columns
|
0 commit comments