-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathFunctions.hpp
89 lines (71 loc) · 4.06 KB
/
Functions.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/*******************************************************
* Copyright (c) 2017, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#pragma once
#include <arrayfire.h>
#include <vector>
namespace af {
namespace autograd {

// Forward declaration only; the full class is defined elsewhere in the project.
// Every function below returns a Variable by value, which is permitted for an
// incomplete type in a declaration.
class Variable;

// ---------------------------------------------------------------------------
// Binary operators where both operands are Variables.
// This header only declares them; the definitions live outside this file.
// ---------------------------------------------------------------------------
Variable operator+(const Variable& lhs, const Variable& rhs);
Variable operator*(const Variable& lhs, const Variable& rhs);
Variable operator-(const Variable& lhs, const Variable& rhs);
Variable operator/(const Variable& lhs, const Variable& rhs);
Variable operator>(const Variable& lhs, const Variable& rhs);
Variable operator<(const Variable& lhs, const Variable& rhs);
Variable operator>=(const Variable& lhs, const Variable& rhs);
Variable operator<=(const Variable& lhs, const Variable& rhs);

// ---------------------------------------------------------------------------
// The same operators with a scalar (double) as the left-hand operand.
// NOTE(review): scalars are taken as `const double&`. The signatures must match
// the out-of-view definitions exactly, so do not change them to pass-by-value
// here alone.
// ---------------------------------------------------------------------------
Variable operator+(const double& lhs, const Variable& rhs);
Variable operator*(const double& lhs, const Variable& rhs);
Variable operator-(const double& lhs, const Variable& rhs);
Variable operator/(const double& lhs, const Variable& rhs);
Variable operator>(const double& lhs, const Variable& rhs);
Variable operator<(const double& lhs, const Variable& rhs);
Variable operator>=(const double& lhs, const Variable& rhs);
Variable operator<=(const double& lhs, const Variable& rhs);

// ---------------------------------------------------------------------------
// The same operators with a scalar (double) as the right-hand operand.
// ---------------------------------------------------------------------------
Variable operator+(const Variable& lhs, const double& rhs);
Variable operator*(const Variable& lhs, const double& rhs);
Variable operator-(const Variable& lhs, const double& rhs);
Variable operator/(const Variable& lhs, const double& rhs);
Variable operator>(const Variable& lhs, const double& rhs);
Variable operator<(const Variable& lhs, const double& rhs);
Variable operator>=(const Variable& lhs, const double& rhs);
Variable operator<=(const Variable& lhs, const double& rhs);

// ---------------------------------------------------------------------------
// Single-input operations. The exact semantics (elementwise or not) are
// defined by the implementation, which is not visible here.
// ---------------------------------------------------------------------------
Variable operator!(const Variable& input);
Variable negate(const Variable& input);
Variable reciprocal(const Variable& input);
Variable exp(const Variable& input);
Variable log(const Variable& input);
Variable sin(const Variable& input);
Variable cos(const Variable& input);
Variable tanh(const Variable& input);
Variable sigmoid(const Variable& input);

// ---------------------------------------------------------------------------
// max / min overloads for each Variable / scalar operand combination.
// ---------------------------------------------------------------------------
Variable max(const Variable& lhs, const Variable& rhs);
Variable max(const Variable& lhs, const double& rhs);
Variable max(const double& lhs, const Variable& rhs);
Variable min(const Variable& lhs, const Variable& rhs);
Variable min(const Variable& lhs, const double& rhs);
Variable min(const double& lhs, const Variable& rhs);

// ---------------------------------------------------------------------------
// Shape manipulation and reductions.
// tileAs / sumAs take a `reference` Variable. Presumably they tile or reduce
// `input` to match `reference`'s dimensions — TODO confirm against the
// implementation.
// tile / sum / mean take their repeat counts or axes as a std::vector<int>.
// ---------------------------------------------------------------------------
Variable transpose(const Variable& input);
Variable tileAs(const Variable& input, const Variable& reference);
Variable sumAs(const Variable& input, const Variable& reference);
Variable tile(const Variable& input, const std::vector<int>& repeats);
Variable sum(const Variable& input, const std::vector<int>& axes);
Variable mean(const Variable& input, const std::vector<int>& axes);

// ---------------------------------------------------------------------------
// Matrix products. Going by the naming convention, the TN / NT suffixes
// presumably mark a transposed lhs / rhs respectively — confirm in the
// implementation.
// ---------------------------------------------------------------------------
Variable matmul(const Variable& lhs, const Variable& rhs);
Variable matmulTN(const Variable& lhs, const Variable& rhs);
Variable matmulNT(const Variable& lhs, const Variable& rhs);

// ---------------------------------------------------------------------------
// Miscellaneous operations.
// `dim4` is unqualified and is looked up in the enclosing namespace `af`
// (declared via <arrayfire.h>).
// ---------------------------------------------------------------------------
Variable abs(const Variable& input);
Variable flat(const Variable& input);
Variable moddims(const Variable& input, const dim4& dims);
Variable reorder(const Variable& input, int d0, int d1, int d2, int d3);

// ---------------------------------------------------------------------------
// Window / convolution operations.
// NOTE(review): the parameter names appear to follow ArrayFire's af::unwrap /
// af::wrap convention (window wx/wy, stride sx/sy, padding px/py, output
// ox/oy) — assumed, verify against the implementation.
// ---------------------------------------------------------------------------
Variable unwrap(const Variable& input,
                int wx, int wy,
                int sx, int sy,
                int px, int py);
Variable wrap(const Variable& input,
              int ox, int oy,
              int wx, int wy,
              int sx, int sy,
              int px, int py);
Variable convolve2(const Variable& input, const Variable& weights,
                   int wx, int wy,
                   int sx, int sy,
                   int px, int py);

}  // namespace autograd
}  // namespace af