Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Jun 21, 2024
1 parent 69f8079 commit 2b02af6
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 47 deletions.
4 changes: 4 additions & 0 deletions tests/test_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class CodeComparerFactory(CustomDataclassFactory[wr.CodeComparer]):
class CustomChartFactory(CustomDataclassFactory[wr.CustomChart]):
__model__ = wr.CustomChart

@classmethod
def query(cls):
return {"history": {"keys": ["x", "y"], "id": None, "name": None}}


@register_fixture
class LinePlotFactory(CustomDataclassFactory[wr.LinePlot]):
Expand Down
88 changes: 68 additions & 20 deletions wandb_workspaces/reports/v2/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ class PanelGrid(Block):
runsets: LList["Runset"] = Field(default_factory=lambda: [Runset()])
panels: LList["PanelTypes"] = Field(default_factory=list)
active_runset: int = 0
custom_run_colors: Dict[Union[RunId, RunsetGroup], str] = Field(
custom_run_colors: Dict[Union[RunId, RunsetGroup], Union[str, dict]] = Field(
default_factory=dict
)

Expand Down Expand Up @@ -1259,8 +1259,9 @@ def from_model(cls, model: internal.ScatterPlot):

@dataclass(config=dataclass_config, repr=False)
class CustomChart(Panel):
# Custom chart configs should look exactly like they do in the UI. Please check the query carefully!
query: dict = Field(default_factory=dict)
chart_name: str = Field(default_factory=dict)
chart_name: str = Field(default_factory=str)
chart_fields: dict = Field(default_factory=dict)
chart_strings: dict = Field(default_factory=dict)

Expand All @@ -1275,33 +1276,80 @@ def from_table(
)

def to_model(self):
return internal.Vega2(
def dict_to_fields(d):
fields = []
for k, v in d.items():
if k in ("runSets", "limit"):
continue
if isinstance(v, dict) and len(v) > 0:
field = internal.QueryField(
name=k, args=dict_to_fields(v), fields=[]
)
elif isinstance(v, dict) and len(v) == 0 or v is None:
field = internal.QueryField(name=k, fields=[])
else:
field = internal.QueryField(name=k, value=v)
fields.append(field)
return fields

d = self.query
d.setdefault("id", None)
d.setdefault("name", None)

_query = [
internal.QueryField(
name="runSets",
args=[
internal.QueryField(name="runSets", value=r"${runSets}"),
internal.QueryField(name="limit", value=500),
],
fields=dict_to_fields(d),
)
]
user_query = internal.UserQuery(query_fields=_query)

obj = internal.Vega2(
config=internal.Vega2Config(
# user_query=internal.UserQuery(
# query_fields=[
# internal.QueryField(
# args=...,
# fields=...,
# name=...,
# )
# ]
# )
user_query=user_query,
panel_def_id=self.chart_name,
field_settings=self.chart_fields,
string_settings=self.chart_strings,
),
layout=self.layout.to_model(),
id=self._id,
)
obj.ref = self._ref
return obj

@classmethod
def from_model(cls, model: internal.ScatterPlot):
def from_model(cls, model: internal.Vega2):
def fields_to_dict(fields):
d = {}
for field in fields:
if field.args:
for arg in field.args:
d[arg.name] = arg.value

if field.fields:
for subfield in field.fields:
if subfield.args is not None:
d[subfield.name] = fields_to_dict(subfield.args)
else:
d[subfield.name] = subfield.value

d[field.name] = field.value

return d

query = fields_to_dict(model.config.user_query.query_fields)

obj = cls(
# query=model.config.user_query.query_fields,
# chart_name=model.config.panel_def_id,
# chart_fields=model.config.field_settings,
# chart_strings=model.config.string_settings,
query=query,
chart_name=model.config.panel_def_id,
chart_fields=model.config.field_settings,
chart_strings=model.config.string_settings,
layout=Layout.from_model(model.layout),
)

obj._id = model.id
obj._ref = model.ref
return obj


Expand Down
38 changes: 11 additions & 27 deletions wandb_workspaces/reports/v2/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, Optional, Tuple, Union
from typing import List as LList

from annotated_types import Annotated, Ge, Le
from annotated_types import Annotated, Ge, Le, Len

try:
from typing import Literal
Expand Down Expand Up @@ -74,7 +74,9 @@ def base_repr(number: int, base: int, padding: int = 0) -> str:
GroupArea = Literal["minmax", "stddev", "stderr", "none", "samples"]
Mark = Literal["solid", "dashed", "dotted", "dotdash", "dotdotdash"]
Timestep = Literal["seconds", "minutes", "hours", "days"]
SmoothingType = Literal["exponentialTimeWeighted", "exponential", "gaussian", "average", "none"]
SmoothingType = Literal[
"exponentialTimeWeighted", "exponential", "gaussian", "average", "none"
]
CodeCompareDiff = Literal["split", "unified"]
Range = Tuple[Optional[float], Optional[float]]
Language = Literal["javascript", "python", "css", "json", "html", "markdown", "yaml"]
Expand Down Expand Up @@ -286,7 +288,7 @@ class PanelGridMetadata(ReportAPIBaseModel):
panel_bank_section_config: PanelBankSectionConfig = Field(
default_factory=PanelBankSectionConfig
)
custom_run_colors: Dict[str, str] = Field(default_factory=dict)
custom_run_colors: Dict[str, Union[str, dict]] = Field(default_factory=dict)
# custom_run_colors: PanelGridCustomRunColors = Field(
# default_factory=PanelGridCustomRunColors
# )
Expand Down Expand Up @@ -769,35 +771,17 @@ class RunComparer(Panel):
config: RunComparerConfig


class QueryFieldsValue(ReportAPIBaseModel):
name: str
value: Any


class QueryFieldsField(ReportAPIBaseModel):
name: str = ""
fields: Optional[LList["QueryFieldsField"]] = None
value: LList[QueryFieldsValue] = Field(default_factory=list)


class QueryField(ReportAPIBaseModel):
args: LList[QueryFieldsValue] = Field(
default_factory=lambda: [
QueryFieldsValue(name="runSets", value="${runSets}"),
QueryFieldsValue(name="limit", value=500),
]
)
fields: LList[QueryFieldsField] = Field(
default_factory=lambda: [
QueryFieldsField(name="id", value=[], fields=None),
QueryFieldsField(name="name", value=[], fields=None),
]
)
name: str = "runSets"
args: Optional[LList["QueryField"]] = None
fields: Optional[LList["QueryField"]] = None
value: Optional[Union[list, float, str]] = None


class UserQuery(ReportAPIBaseModel):
query_fields: LList[QueryField] = Field(default_factory=lambda: [QueryField()])
query_fields: Annotated[LList[QueryField], Len(min_length=1, max_length=1)] = Field(
default_factory=lambda: [QueryField()]
)


class Vega2ConfigTransform(ReportAPIBaseModel):
Expand Down

0 comments on commit 2b02af6

Please sign in to comment.