This repository was archived by the owner on Jul 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 137
/
Copy pathCore.swift
90 lines (76 loc) · 2.75 KB
/
Core.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import _Differentiation
/// A layer that collapses every non-batch dimension of its input into one.
///
/// Dimension 0 is kept as the batch dimension and all trailing dimensions are
/// merged, so the result is a rank-2 tensor.
@frozen
public struct Flatten<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
  public typealias TangentVector = EmptyTangentVector

  /// Creates a flatten layer.
  public init() {}

  /// Flattens `input` into a `[batchSize, featureCount]` tensor.
  ///
  /// - Parameter input: The tensor to flatten. Dimension 0 is read as the batch
  ///   size, so the input is expected to have rank of at least 1.
  /// - Returns: `input` reshaped so that dimension 0 is unchanged and the second
  ///   dimension is the `contiguousSize` of all remaining dimensions.
  @differentiable
  public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
    let inputShape = input.shape
    // Read the batch dimension first so out-of-range failures surface here.
    let batchSize = inputShape[0]
    let featureCount = inputShape[1..<input.rank].contiguousSize
    let flattenedShape: TensorShape = [batchSize, featureCount]
    return input.reshaped(to: flattenedShape)
  }
}
/// A layer that reshapes every input to a fixed target shape.
@frozen
public struct Reshape<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
  public typealias TangentVector = EmptyTangentVector

  /// The shape inputs are reshaped to, held as a tensor of dimension sizes.
  @noDerivative public var shape: Tensor<Int32>

  // Retained as a workaround for TF-331; this property is never read inside
  // this struct. Its position in the stored-property list must not change
  // because the struct is `@frozen`.
  @usableFromInline
  internal var _nontrivial = Tensor<Float>(0)

  /// Creates a reshape layer from a shape tensor.
  ///
  /// - Parameter shape: The target shape as a tensor of `Int32` dimension sizes.
  public init(shape: Tensor<Int32>) {
    self.shape = shape
  }

  /// Creates a reshape layer from a `TensorShape`.
  ///
  /// - Parameter shape: The target shape. Each dimension is converted to
  ///   `Int32` (a dimension that does not fit in `Int32` traps).
  public init(_ shape: TensorShape) {
    let dimensionSizes = shape.dimensions.map { Int32($0) }
    self.init(shape: Tensor<Int32>(dimensionSizes))
  }

  /// Reshapes `input` to the stored target shape.
  ///
  /// - Parameter input: The tensor to reshape.
  /// - Returns: `input` with its elements laid out in `shape`.
  @differentiable
  public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
    input.reshaped(toShape: self.shape)
  }
}
/// A parameterless layer that wraps an arbitrary differentiable closure.
///
/// Applying the layer forwards its input straight to the wrapped closure.
public struct Function<Input: Differentiable, Output: Differentiable>: ParameterlessLayer {
  public typealias TangentVector = EmptyTangentVector

  /// The type of closure this layer wraps.
  public typealias Body = @differentiable (Input) -> Output

  /// The closure invoked each time the layer is applied.
  @noDerivative public let body: Body

  /// Creates a layer that applies `body` to its input.
  ///
  /// - Parameter body: The differentiable closure to wrap.
  public init(_ body: @escaping Body) {
    self.body = body
  }

  /// Applies the wrapped closure to `input`.
  ///
  /// - Parameter input: The value passed to `body`.
  /// - Returns: The result of `body(input)`.
  @differentiable
  public func callAsFunction(_ input: Input) -> Output {
    let output = self.body(input)
    return output
  }
}