Models
ModelBase

class texar.tf.models.ModelBase(hparams=None)
Base class inherited by all model classes.
A model class implements interfaces that are compatible with TF Estimator. In particular, _build() implements the model_fn interface, and get_input_fn() implements the input_fn interface.

_build(features, labels, params, mode, config=None)
Used as the model_fn argument when constructing tf.estimator.Estimator.

get_input_fn(*args, **kwargs)
Returns the input_fn function that constructs the input data for tf.estimator.Estimator.
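The call pattern between these two methods and an Estimator can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, assuming a hypothetical `ToyModel` class and a hypothetical `run_once` driver that stands in for tf.estimator.Estimator; neither is part of Texar. It only shows the contract: get_input_fn() returns a zero-argument input_fn that yields (features, labels), and _build is then called with the model_fn arguments.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Features = Dict[str, Any]
Labels = Dict[str, Any]


class ToyModel:
    """Hypothetical stand-in mirroring the ModelBase interface (no TensorFlow)."""

    def __init__(self, hparams: Optional[Dict[str, Any]] = None):
        self._hparams = hparams or {}

    def _build(self, features: Features, labels: Labels, params: Any,
               mode: str, config: Any = None) -> Dict[str, Any]:
        # A real model_fn returns a tf.estimator.EstimatorSpec; this toy
        # returns a plain dict with a "loss" computed from the inputs.
        loss = sum(features["x"]) - sum(labels["y"])
        return {"mode": mode, "loss": loss}

    def get_input_fn(self, *args, **kwargs) -> Callable[[], Tuple[Features, Labels]]:
        def input_fn() -> Tuple[Features, Labels]:
            # input_fn takes no arguments and returns (features, labels).
            return {"x": [1, 2, 3]}, {"y": [1, 1]}
        return input_fn


def run_once(model_fn, input_fn, mode: str, params=None, config=None):
    """Hypothetical driver: calls input_fn, then model_fn with keyword args."""
    features, labels = input_fn()
    return model_fn(features=features, labels=labels, params=params,
                    mode=mode, config=config)


model = ToyModel()
spec = run_once(model._build, model.get_input_fn(), mode="train")
print(spec)  # {'mode': 'train', 'loss': 4}
```

Passing the bound method model._build as model_fn is what lets a single model object serve the Estimator, while get_input_fn keeps data construction separate from the graph-building code.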

Seq2seqBase

class texar.tf.models.Seq2seqBase(data_hparams, hparams=None)
Base class inherited by all seq2seq model classes.

_build(features, labels, params, mode, config=None)
Used as the model_fn argument when constructing tf.estimator.Estimator.

static default_hparams()
Returns a dictionary of hyperparameters with default values.
    {
        "source_embedder": "WordEmbedder",
        "source_embedder_hparams": {},
        "target_embedder": "WordEmbedder",
        "target_embedder_hparams": {},
        "embedder_share": True,
        "embedder_hparams_share": True,
        "encoder": "UnidirectionalRNNEncoder",
        "encoder_hparams": {},
        "decoder": "BasicRNNDecoder",
        "decoder_hparams": {},
        "decoding_strategy_train": "train_greedy",
        "decoding_strategy_infer": "infer_greedy",
        "beam_search_width": 0,
        "connector": "MLPTransformConnector",
        "connector_hparams": {},
        "optimization": {},
        "name": "seq2seq",
    }
Here:
- “source_embedder”: str or class or instance
- Word embedder for source text. Can be a class, its name or module path, or a class instance.
- “source_embedder_hparams”: dict
- Hyperparameters for constructing the source embedder. See default_hparams() of WordEmbedder for its hyperparameters. Ignored if “source_embedder” is an instance.
- “target_embedder”, “target_embedder_hparams”:
- Same as “source_embedder” and “source_embedder_hparams” but for target text embedder.
- “embedder_share”: bool
- Whether to share the embedder between source and target. If True, the source embedder is also used to embed target text.
- “embedder_hparams_share”: bool
- Whether to share the embedder configurations. If True, the target embedder is created with “source_embedder_hparams”, but the two embedders still have separate sets of trainable variables.
- “encoder”, “encoder_hparams”:
- Same as “source_embedder” and “source_embedder_hparams” but for encoder.
- “decoder”, “decoder_hparams”:
- Same as “source_embedder” and “source_embedder_hparams” but for decoder.
- “decoding_strategy_train”: str
- The decoding strategy in training mode. See _build() for details.
- “decoding_strategy_infer”: str
- The decoding strategy in eval/inference mode.
- “beam_search_width”: int
- Beam width. If > 1, beam search is used in eval/inference mode.
- “connector”, “connector_hparams”:
- The connector class and hyperparameters. A connector transforms the encoder's final state into the decoder's initial state.
- “optimization”: dict
- Hyperparameters for optimizing the model. See default_optimization_hparams() for details.
- “name”: str
- Name of the model.
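To make the interplay of these options concrete, the sketch below fills a user override dict from a subset of the defaults above (the encoder, decoder, connector, and optimization keys are omitted) and then derives which hyperparameters the target embedder uses, whether the source embedder is reused, and which decoding strategy applies in a given mode. It is a hypothetical, simplified illustration: `merge_hparams` and `resolve` are not Texar functions, the recursive merge only approximates how Texar's HParams fills in defaults, the mode strings follow the values of tf.estimator.ModeKeys ("train", "eval", "infer"), and "beam_search" is a placeholder label, not a strategy name from the source.

```python
import copy
from typing import Any, Dict, Optional

# Subset of Seq2seqBase.default_hparams(), copied from the listing above.
DEFAULTS: Dict[str, Any] = {
    "source_embedder": "WordEmbedder",
    "source_embedder_hparams": {},
    "target_embedder": "WordEmbedder",
    "target_embedder_hparams": {},
    "embedder_share": True,
    "embedder_hparams_share": True,
    "decoding_strategy_train": "train_greedy",
    "decoding_strategy_infer": "infer_greedy",
    "beam_search_width": 0,
    "name": "seq2seq",
}


def merge_hparams(defaults: Dict[str, Any],
                  overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Recursively fill keys missing from `overrides` with `defaults` (simplified)."""
    merged = copy.deepcopy(defaults)
    for key, value in (overrides or {}).items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_hparams(merged[key], value)
        else:
            merged[key] = value
    return merged


def resolve(hparams: Dict[str, Any], mode: str) -> Dict[str, Any]:
    """Derive embedder and decoding choices as described in the list above."""
    if hparams["embedder_share"]:
        target_embedder = "shared with source"
        target_hparams = hparams["source_embedder_hparams"]
    elif hparams["embedder_hparams_share"]:
        # Separate embedder (own trainable variables) built from source hparams.
        target_embedder = "separate"
        target_hparams = hparams["source_embedder_hparams"]
    else:
        target_embedder = "separate"
        target_hparams = hparams["target_embedder_hparams"]

    if mode == "train":
        strategy = hparams["decoding_strategy_train"]
    elif hparams["beam_search_width"] > 1:
        strategy = "beam_search"  # placeholder label, not a Texar strategy name
    else:
        strategy = hparams["decoding_strategy_infer"]
    return {"target_embedder": target_embedder,
            "target_embedder_hparams": target_hparams,
            "decoding_strategy": strategy}


hp = merge_hparams(DEFAULTS, {"embedder_share": False,
                              "source_embedder_hparams": {"dim": 256},
                              "beam_search_width": 5})
print(resolve(hp, "train")["decoding_strategy"])  # train_greedy
print(resolve(hp, "infer")["decoding_strategy"])  # beam_search
```

With "embedder_share" turned off but "embedder_hparams_share" left at its default of True, the target side gets its own embedder configured from the illustrative {"dim": 256} source override.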

get_input_fn(mode, hparams=None)
Creates an input function, input_fn, that provides input data for the model in an Estimator. See, e.g., tf.estimator.train_and_evaluate.
Parameters:
- mode – One of the members of tf.estimator.ModeKeys.
- hparams – A dict or an HParams instance containing the hyperparameters of PairedTextData. See default_hparams() of PairedTextData for the structure and default values of the hyperparameters.
Returns: An input function that returns a tuple (features, labels) when called. features contains data fields related to source text, and labels contains data fields related to target text. See PairedTextData for all data fields.
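The split into features and labels can be pictured as partitioning one batch of PairedTextData fields by side. The sketch below is a hypothetical pure-Python illustration, assuming that source-side fields are prefixed with "source_" and target-side fields with "target_" (as in names like source_text_ids or target_length); `split_batch` is not a Texar function, and a real batch would hold tensors rather than Python lists.

```python
from typing import Any, Dict, Tuple


def split_batch(batch: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition a data batch into (features, labels) by field-name prefix."""
    features = {k: v for k, v in batch.items() if k.startswith("source_")}
    labels = {k: v for k, v in batch.items() if k.startswith("target_")}
    return features, labels


# Illustrative batch of one example; field names assumed, values made up.
batch = {
    "source_text_ids": [[1, 5, 7, 2]],
    "source_length": [4],
    "target_text_ids": [[1, 9, 2]],
    "target_length": [3],
}
features, labels = split_batch(batch)
print(sorted(features))  # ['source_length', 'source_text_ids']
print(sorted(labels))    # ['target_length', 'target_text_ids']
```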

BasicSeq2seq

class texar.tf.models.BasicSeq2seq(data_hparams, hparams=None)
The basic seq2seq model (without attention).
Example
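    import texar.tf as tx
    from texar.tf.models import BasicSeq2seq

    # data_hparams, model_hparams, and run_config (a tf.estimator.RunConfig)
    # are assumed to be defined beforehand.
    model = BasicSeq2seq(data_hparams, model_hparams)
    exor = tx.run.Executor(
        model=model, data_hparams=data_hparams, config=run_config)
    exor.train_and_evaluate(max_train_steps=10000, eval_steps=100)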

_build(features, labels, params, mode, config=None)
Used as the model_fn argument when constructing tf.estimator.Estimator.

static default_hparams()
Returns a dictionary of hyperparameters with default values.
Same as default_hparams() of Seq2seqBase.

get_input_fn(mode, hparams=None)
Creates an input function, input_fn, that provides input data for the model in an Estimator. See, e.g., tf.estimator.train_and_evaluate.
Parameters:
- mode – One of the members of tf.estimator.ModeKeys.
- hparams – A dict or an HParams instance containing the hyperparameters of PairedTextData. See default_hparams() of PairedTextData for the structure and default values of the hyperparameters.
Returns: An input function that returns a tuple (features, labels) when called. features contains data fields related to source text, and labels contains data fields related to target text. See PairedTextData for all data fields.

get_loss(decoder_results, features, labels)
Computes the training loss.
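The exact loss is not specified here. A common choice for seq2seq training, and an assumption in this sketch, is token-level cross-entropy between the decoder's predicted distributions and the target ids, masked by each sequence's length so that padding steps do not contribute. The hypothetical `masked_sequence_xent` below works on probabilities in plain Python rather than on TensorFlow logits and averages over valid tokens; Texar's sequence losses, such as texar.tf.losses.sequence_sparse_softmax_cross_entropy, expose options for how to average over time steps and batch and may differ in detail.

```python
import math
from typing import Sequence


def masked_sequence_xent(probs: Sequence[Sequence[Sequence[float]]],
                         targets: Sequence[Sequence[int]],
                         lengths: Sequence[int]) -> float:
    """Mean negative log-likelihood over non-padding target tokens.

    probs[b][t][v]: predicted probability of vocab id v at step t of example b.
    targets[b][t]:  gold token id; steps t >= lengths[b] are padding.
    """
    total, count = 0.0, 0
    for b, length in enumerate(lengths):
        for t in range(length):  # steps at or beyond `length` are masked out
            total -= math.log(probs[b][t][targets[b][t]])
            count += 1
    return total / count if count else 0.0


probs = [
    [[0.5, 0.5], [0.9, 0.1]],    # example 0: two valid steps
    [[0.25, 0.75], [0.5, 0.5]],  # example 1: one valid step, one padding step
]
targets = [[0, 0], [1, 0]]
lengths = [2, 1]
loss = masked_sequence_xent(probs, targets, lengths)
print(f"{loss:.4f}")  # 0.3621
```

Averaging over valid tokens keeps the loss scale independent of how much padding a batch contains; summing over time steps instead would weight longer sequences more heavily.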