# NOTE(review): original file lines 14-110 (imports, constants such as
# PAPERMILL_TASK_PREFIX, and the _dummy_task_func helper) were lost during
# extraction -- only their line-number gutter survived here. Restore them
# from the original source before running this module.
class NoteableNotebookTask(NotebookTask):
    """A ``NotebookTask`` that executes notebooks through papermill's ``noteable`` engine.

    Compared to the stock flytekit papermill ``NotebookTask``, this subclass:

    * builds its internal "config" task itself and deliberately skips
      ``NotebookTask.__init__`` (calling ``PythonInstanceTask``'s parent
      ``__init__`` directly) so it controls the generated task name/type;
    * runs papermill with ``engine_name="noteable"`` instead of the default engine;
    * resolves the output/rendered notebook paths through
      ``parse_noteable_file_id`` (presumably mapping a Noteable file id to a
      local filesystem path -- confirm against that helper's definition).
    """

    def __init__(
        self,
        name: str,
        notebook_path: str,
        render_deck: bool = False,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        outputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        **kwargs,
    ):
        """Create the task.

        :param name: unique task name; also used to name the internal config task.
        :param notebook_path: path of the notebook to execute with papermill.
        :param render_deck: stored on the instance; consumed elsewhere (not in this block).
        :param task_config: plugin-specific config; its type selects the task plugin.
        :param inputs: declared notebook input names -> Python types.
        :param outputs: declared notebook output names -> Python types. When provided,
            the implicit output-notebook / rendered-notebook entries are added as well.
        :param kwargs: forwarded both to the internal config task and to the base task.
        """
        # Each instance of NotebookTask instantiates an underlying task with a dummy
        # function that will only be used to run pre- and post- execute functions using
        # the corresponding task plugin. We rename the function here to ensure the
        # generated task has a unique name and avoid duplicate task name errors.
        # HACK: this should use a plugin_class that doesn't require a fake function.
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        self._config_task_instance = plugin_class(
            task_config=task_config, task_function=_dummy_task_func, **kwargs
        )
        # Rename the internal task so that there are no conflicts at serialization time.
        # Technically these internal tasks should not be serialized at all, but we don't
        # currently have a mechanism for skipping Flyte entities at serialization time.
        self._config_task_instance._name = f"{PAPERMILL_TASK_PREFIX}.{name}"
        task_type = f"{self._config_task_instance.task_type}"
        task_type_version = self._config_task_instance.task_type_version
        self._notebook_path = notebook_path
        self._render_deck = render_deck
        if outputs:
            # Build a new dict instead of mutating the caller's `outputs` in place:
            # the original `outputs.update(...)` leaked the implicit notebook outputs
            # back into the caller's own mapping.
            outputs = {
                **outputs,
                self._IMPLICIT_OP_NOTEBOOK: self._IMPLICIT_OP_NOTEBOOK_TYPE,
                self._IMPLICIT_RENDERED_NOTEBOOK: self._IMPLICIT_RENDERED_NOTEBOOK_TYPE,
            }
        # Avoid the call to NotebookTask.__init__, which does things we don't want;
        # instead call its parent class's __init__ directly.
        super(PythonInstanceTask, self).__init__(
            name,
            task_config,
            task_type=task_type,
            task_type_version=task_type_version,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    @property
    def output_notebook_path(self) -> str:
        # Ensure the output path is on the local filesystem.
        return parse_noteable_file_id(super().output_notebook_path)

    @property
    def rendered_output_path(self) -> str:
        # Ensure the rendered output path is on the local filesystem.
        return parse_noteable_file_id(super().rendered_output_path)

    def execute(self, **kwargs) -> Any:
        """Run the notebook via papermill and collect the declared outputs.

        TODO: Figure out how to share FlyteContext ExecutionParameters with the notebook
        kernel (as the notebook kernel is executed in a separate python process).
        For Spark, the notebooks today need to use the new_session or just getOrCreate
        session and get a handle to the singleton.

        :returns: a tuple of output values in ``python_interface.outputs`` order; the
            implicit entries are the output/rendered notebook paths.
        :raises RuntimeError: if a declared output is missing from the notebook outputs.
        """
        logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.")
        # Execute Notebook via Papermill, through the Noteable engine (changed from upstream).
        pm.execute_notebook(
            self._notebook_path,
            self.output_notebook_path,
            parameters=kwargs,
            engine_name="noteable",
        )  # type: ignore
        outputs = self.extract_outputs(self.output_notebook_path)
        self.render_nb_html(self.output_notebook_path, self.rendered_output_path)
        m = {}
        if outputs:
            m = outputs.literals
        output_list = []
        for k, type_v in self.python_interface.outputs.items():
            if k == self._IMPLICIT_OP_NOTEBOOK:
                output_list.append(self.output_notebook_path)
            elif k == self._IMPLICIT_RENDERED_NOTEBOOK:
                output_list.append(self.rendered_output_path)
            elif k in m:
                v = TypeEngine.to_python_value(
                    ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v
                )
                output_list.append(v)
            else:
                raise RuntimeError(
                    f"Expected output {k} of type {type_v} not found in the notebook outputs"
                )
        return tuple(output_list)