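# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================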
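
"""Helper functions for creating partitioned variables.

This is a convenient abstraction to partition a large variable across
multiple smaller variables that can be assigned to different devices.

The full variable can be reconstructed by concatenating the smaller variables.
Using partitioned variables instead of a single variable is mostly a
performance choice. However, it also has an impact on:

1. Random initialization, as the random number generator is called once per
   slice.
2. Updates, as they happen in parallel across slices.

A key design goal is to allow a different graph to repartition a variable
with the same name but different slicings, including possibly no partitions.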
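
A minimal sketch of the split/concatenate round trip and of repartitioning.
It uses Python lists of rows as stand-ins for a variable's values.
`split_rows` and `concat_rows` are hypothetical helpers, not TensorFlow APIs.
The way `split_rows` distributes leftover rows is an arbitrary choice made
for illustration:

```python
def split_rows(rows, num_slices):
  # Split a list of rows into num_slices contiguous, nearly equal slices.
  bounds = [len(rows) * i // num_slices for i in range(num_slices + 1)]
  return [rows[bounds[i]:bounds[i + 1]] for i in range(num_slices)]


def concat_rows(slices):
  # Reconstruct the full variable by concatenating the slices in order.
  return [row for piece in slices for row in piece]


full = [[r, r * 10] for r in range(7)]
saved = split_rows(full, 3)  # one graph uses 3 slices
# Another graph restores the same values with 2 slices, or with none (1 slice).
for n in (2, 1):
  assert concat_rows(split_rows(concat_rows(saved), n)) == full
print([len(s) for s in saved])  # [2, 2, 3]
```

Only the concatenated value matters. A graph that uses 3 slices and one that
uses 2 or 1 therefore agree on the full variable.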
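
TODO(touts): If an initializer provides a seed, the seed must be changed
deterministically for each slice, maybe by adding one to it, otherwise each
slice will use the same values. Maybe this can be done by passing the
slice offsets to the initializer functions.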
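
The problem the TODO describes can be illustrated with Python's
`random.Random` standing in for a seeded initializer. `init_slice` is a
hypothetical function, not a TensorFlow initializer. Reusing one seed for
every slice repeats the values, and a deterministic per-slice offset avoids
this:

```python
import random


def init_slice(num_values, seed):
  # Hypothetical seeded initializer: the same seed yields the same values.
  rng = random.Random(seed)
  return [rng.random() for _ in range(num_values)]


seed = 42
# Passing the initializer's seed unchanged to every slice repeats the values.
same = [init_slice(4, seed) for _ in range(3)]
# Offsetting the seed deterministically per slice (here seed + i) does not.
offset = [init_slice(4, seed + i) for i in range(3)]
print(same[0] == same[1] == same[2])  # True
print(offset[0] == offset[1])         # False
```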

Typical usage:

```python
# Create a list of partitioned variables with:
vs = create_partitioned_variables(
    <shape>, <slicing>, <initializer>, name=<optional-name>)

# Pass the list as inputs to embedding_lookup for sharded, parallel lookup:
y = embedding_lookup(vs, ids, partition_strategy="div")

# Or fetch the variables in parallel to speed up large matmuls:
z = matmul(x, concat(vs, slice_dim))
```
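
For `partition_strategy="div"`, TensorFlow's `embedding_lookup` documentation
describes a contiguous assignment of ids to partitions. For example, 13 ids
over 5 partitions become [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
The sketch below is an assumption chosen to reproduce that documented layout,
not a copy of TensorFlow's code. `div_partition` is a hypothetical helper:

```python
def div_partition(ids, num_total_ids, num_partitions):
  # Hypothetical helper: assign ids to partitions contiguously, with the
  # first (num_total_ids % num_partitions) partitions holding one extra id.
  # Assumes num_total_ids >= num_partitions.
  ids_per_partition = num_total_ids // num_partitions
  extras = num_total_ids % num_partitions
  return [max(i // (ids_per_partition + 1), (i - extras) // ids_per_partition)
          for i in ids]


print(div_partition(range(13), 13, 5))
# [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4]
```

Earlier partitions absorb the remainder here. This is the same shape as the
uneven slicing rule documented for `create_partitioned_variables` below.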
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging

__all__ = [
    "create_partitioned_variables",
    "variable_axis_size_partitioner",
    "min_max_variable_partitioner",
    "fixed_size_partitioner",
]


def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below `max_shard_bytes`. In practice, this is not
  always possible when sharding along only one axis. When this happens,
  this axis is sharded as much as possible (i.e., every slice along the axis
  becomes a separate shard).

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
  `64MB`, to keep below the protobuf byte limit.

  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along. Default: outermost axis.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.
    max_shards: The maximum number of shards to create (an `int`). Takes
      precedence over `max_shard_bytes`.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  Raises:
    ValueError: If any of the byte counts are non-positive.
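
The shard-count arithmetic described above can be sketched in plain Python.
This is a hedged illustration, not the implementation. `axis_shards_for` is a
hypothetical helper that takes a list of ints and a per-element byte size in
place of a `TensorShape` and `DType`. It mirrors the steps of the
`_partitioner` closure defined in this function:

```python
import math


def axis_shards_for(shape, element_size, max_shard_bytes, axis=0,
                    max_shards=None):
  # Hypothetical helper mirroring _partitioner with plain Python ints.
  num_elements = 1
  for dim in shape:
    num_elements *= dim
  # Bytes in one slice, i.e. one index along `axis`.
  bytes_per_slice = 1.0 * (num_elements / shape[axis]) * element_size
  # At least one slice per shard, even if that slice exceeds the limit.
  slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
  axis_shards = int(math.ceil(1.0 * shape[axis] / slices_per_shard))
  if max_shards:
    axis_shards = min(max_shards, axis_shards)
  partitions = [1] * len(shape)
  partitions[axis] = axis_shards
  return partitions


# float32 [1024, 1024] (4 bytes per element) with a 1MB limit: each row is
# 4096 bytes, 256 rows fit in one shard, so 4 shards.
print(axis_shards_for([1024, 1024], 4, 1 << 20))  # [4, 1]
# A max_shards cap wins over max_shard_bytes.
print(axis_shards_for([1024, 1024], 4, 1 << 20, max_shards=2))  # [2, 1]
# A single row larger than the limit still yields one row per shard.
print(axis_shards_for([8, 1 << 20], 4, 1 << 20))  # [8, 1]
```

Slices per shard are rounded down and the shard count is rounded up. As long
as one slice fits in `max_shard_bytes` and no `max_shards` cap applies, this
keeps every shard at or below the limit.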
  """
  if max_shard_bytes < 1 or bytes_per_string_element < 1:
    raise ValueError(
        "Both max_shard_bytes and bytes_per_string_element must be positive.")
  if max_shards and max_shards < 1:
    raise ValueError(
        "max_shards must be positive.")

  def _partitioner(shape, dtype):
    """Partitioner that partitions shards to have max_shard_bytes total size.

    Args:
      shape: A `TensorShape`.
      dtype: A `DType`.

    Returns:
      A list giving the number of partitions along each axis of `shape`.

    Raises:
      ValueError: If shape is not a fully defined `TensorShape` or dtype is not
        a `DType`.
    """
    if not isinstance(shape, tensor_shape.TensorShape):
      raise ValueError("shape is not a TensorShape: %s" % shape)
    if not shape.is_fully_defined():
      raise ValueError("shape is not fully defined: %s" % shape)
    if not isinstance(dtype, dtypes.DType):
      raise ValueError("dtype is not a DType: %s" % dtype)

    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis].value slices?
    axis_shards = int(math.ceil(1.0 * shape[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner


def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition has a minimum of `min_slice_size` bytes of the
  variable. The maximum number of such partitions (upper bound) is given by
  `max_partitions`.

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size, in bytes, of the variable slice per
      partition. Defaults to 256KB.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
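
Here is a hedged plain-Python sketch of the same computation.
`min_max_partitions` is a hypothetical helper that takes a list of ints and a
per-element byte size instead of a `TensorShape` and `DType`. It reproduces
the worked float32 `[1024, 1024]` example given in the inner `_partitioner`
docstring:

```python
import math


def min_max_partitions(shape, element_size, max_partitions=1, axis=0,
                       min_slice_size=256 << 10):
  # Hypothetical helper mirroring the returned _partitioner with plain ints.
  if axis >= len(shape):
    raise ValueError('axis %d out of range for shape %s' % (axis, shape))
  num_elements = 1
  for dim in shape:
    num_elements *= dim
  total_size_bytes = num_elements * element_size
  partitions = total_size_bytes / min_slice_size
  partitions_list = [1] * len(shape)
  # Never exceed the axis length or max_partitions, and never go below 1.
  partitions_list[axis] = max(1, min(shape[axis], max_partitions,
                                     int(math.ceil(partitions))))
  return partitions_list


# float32 [1024, 1024]: 4MB total / 256KB = 16 partitions when allowed.
print(min_max_partitions([1024, 1024], 4, max_partitions=32))  # [16, 1]
print(min_max_partitions([1024, 1024], 4, max_partitions=8))   # [8, 1]
# A small variable stays in one piece.
print(min_max_partitions([10, 10], 4, max_partitions=8))       # [1, 1]
```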
  """

  def _partitioner(shape, dtype):
    """Computes the list of partitions for a variable of given shape and type.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If `max_partitions` >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If `max_partitions` < 16, this function would return
        [`max_partitions`, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Cannot partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We cannot partition the variable beyond what its shape or
    # `max_partitions` allows.
    partitions_list[axis] = max(1, min(shape[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner


def fixed_size_partitioner(num_shards, axis=0):
  """Partitioner to specify a fixed number of shards along given axis.

  If `num_shards` exceeds the size of `axis`, the size of `axis` is used
  instead.

  Args:
    num_shards: `int`, number of shards to partition variable.
    axis: `int`, axis to partition on.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
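
A minimal sketch of what the returned partition function computes. It
assumes a plain list of ints for the shape. `fixed_size_partitions` is a
hypothetical name used only for illustration:

```python
def fixed_size_partitions(shape, num_shards, axis=0):
  # Hypothetical stand-in for the returned _partitioner, taking a plain
  # list of ints instead of a TensorShape.
  partitions_list = [1] * len(shape)
  # The shard count is capped at the length of the partitioned axis.
  partitions_list[axis] = min(num_shards, shape[axis])
  return partitions_list


print(fixed_size_partitions([100, 64], 4))          # [4, 1]
print(fixed_size_partitions([3, 64], 4))            # [3, 1]
print(fixed_size_partitions([100, 64], 4, axis=1))  # [1, 4]
```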
  """
  def _partitioner(shape, **unused_args):
    partitions_list = [1] * len(shape)
    partitions_list[axis] = min(num_shards, shape[axis].value)
    return partitions_list
  return _partitioner


def create_partitioned_variables(
    shape, slicing, initializer, dtype=dtypes.float32,
    trainable=True, collections=None, name=None, reuse=None):
  """Create a list of partitioned variables according to the given `slicing`.

  Currently only one dimension of the full variable can be sliced, and the
  full variable can be reconstructed by the concatenation of the returned
  list along that dimension.

  Args:
    shape: List of integers. The shape of the full variable.
    slicing: List of integers. How to partition the variable.
      Must be of the same length as `shape`. Each value
      indicates how many slices to create in the corresponding
      dimension. Presently only one of the values can be more than 1;
      that is, the variable can only be sliced along one dimension.

      For convenience, the requested number of partitions does not have to
      divide the corresponding dimension evenly. If it does not, the
      shapes of the partitions are incremented by 1 starting from partition
      0 until all slack is absorbed. The adjustment rules may change in the
      future, but as you can save/restore these variables with different
      slicing specifications this should not be a problem.
    initializer: A `Tensor` of shape `shape` or a variable initializer
      function. If a function, it will be called once for each slice,
      passing the shape and data type of the slice as parameters. The
      function must return a tensor with the same shape as the slice.
    dtype: Type of the variables. Ignored if `initializer` is a `Tensor`.
    trainable: If True also add all the variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES`.
    collections: List of graph collection keys to add the variables to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name for the full variable. Defaults to
      `"PartitionedVariable"` and gets uniquified automatically.
    reuse: Boolean or `None`; if `True` and name is set, it would reuse
      previously created variables. If `False` it will create new variables.
      If `None`, it would inherit the parent scope reuse.

  Returns:
    A list of Variables corresponding to the slicing.

  Raises:
    ValueError: If any of the arguments is malformed.
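
The uneven-slicing rule described under `slicing` can be sketched as follows.
This is a hypothetical illustration of the documented rule only;
`slice_sizes` and `slice_offsets` are not part of this module. The actual
slice shapes are computed inside `variable_scope`, and the docstring notes
that the adjustment rules may change:

```python
def slice_sizes(dim_size, num_slices):
  # Every slice gets dim_size // num_slices entries, and the remaining slack
  # is handed out one entry at a time starting from slice 0.
  base, slack = divmod(dim_size, num_slices)
  return [base + 1 if i < slack else base for i in range(num_slices)]


def slice_offsets(dim_size, num_slices):
  # Start offset of each slice along the sliced dimension.
  offsets, start = [], 0
  for size in slice_sizes(dim_size, num_slices):
    offsets.append(start)
    start += size
  return offsets


print(slice_sizes(10, 3))    # [4, 3, 3]
print(slice_offsets(10, 3))  # [0, 4, 7]
print(slice_sizes(11, 4))    # [3, 3, 3, 2]
```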
  """
  logging.warn(
      "create_partitioned_variables is deprecated. Use "
      "tf.get_variable with a partitioner set, or "
      "tf.get_partitioned_variable_list, instead.")

  if len(shape) != len(slicing):
    raise ValueError("The 'shape' and 'slicing' of a partitioned Variable "
                     "must have the same length: shape: %s, slicing: %s" %
                     (shape, slicing))
  if len(shape) < 1:
    raise ValueError("A partitioned Variable must have rank at least 1: "
                     "shape: %s" % shape)

  # Legacy: we are provided the slicing directly, so just pass it to
  # the partitioner.
  partitioner = lambda **unused_kwargs: slicing

  with variable_scope.variable_scope(
      name, "PartitionedVariable", reuse=reuse):
    # pylint: disable=protected-access
    partitioned_var = variable_scope._get_partitioned_variable(
        name=None,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=trainable,
        partitioner=partitioner,
        collections=collections)
    return list(partitioned_var)
    # pylint: enable=protected-access