forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathkernel_cache.h
33 lines (23 loc) · 994 Bytes
/
kernel_cache.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#pragma once
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
#include <torch/csrc/jit/ir/ir.h>
#include <cstdint>
#include <functional>
#include <optional>
namespace torch::jit::fuser {

// A thread-safe cache interface for fusion kernel specs.
//
// Overview of the API below:
//  - store() registers a graph and returns an integer key for it.
//  - retrieve() maps such a key back to the associated KernelSpec.
//  - lookupGraph() finds a KernelSpec starting from a graph instead of a key.
// The returned KernelSpec* values are non-owning pointers. Presumably the
// cache keeps ownership, so callers must not delete them (confirm against
// kernel_cache.cpp).

// Normalizes the graph by canonicalizing and erasing shape information.
// Returns the normalized graph.
TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(
    const std::shared_ptr<Graph>& graph);

// Stores the given graph, returning the key used to access it.
// The key is suitable for passing to retrieve().
TORCH_API int64_t store(std::shared_ptr<Graph> graph);

// Given a graph, finds the KernelSpec based on it.
// Returns std::nullopt when no matching KernelSpec is cached.
TORCH_API std::optional<KernelSpec*> lookupGraph(
    const std::shared_ptr<Graph>& graph);

// Returns the KernelSpec corresponding to the given key, as produced by
// store(). Returns std::nullopt if no entry exists for the key.
// Note: a by-value parameter carries no top-level const in the declaration,
// because it does not affect the function type. The definition may still
// declare it const.
TORCH_API std::optional<KernelSpec*> retrieve(int64_t key);

// Returns the size of the fusion key -> KernelSpec cache.
// Only used for testing.
TORCH_API int64_t debugNumCachedKernelSpecs();

} // namespace torch::jit::fuser