# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various utilities specific to data processing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import tarfile
import zipfile
import collections
import numpy as np
from six.moves import urllib
import requests
import tensorflow as tf
from texar.tf.utils import utils_io
# pylint: disable=invalid-name, too-many-branches
# Public API of this module.
__all__ = [
    "maybe_download",
    "read_words",
    "make_vocab",
    "count_file_lines"
]

# True when running under Python 3; used to decide whether file contents
# read via `tf.gfile` need an explicit utf-8 decode (see `read_words`).
Py3 = sys.version_info[0] == 3
def maybe_download(urls, path, filenames=None, extract=False):
    """Downloads a set of files.

    Args:
        urls: A (list of) urls to download files.
        path (str): The destination path to save the files.
        filenames: A (list of) strings of the file names. If given,
            must have the same length with :attr:`urls`. If `None`,
            filenames are extracted from :attr:`urls`.
        extract (bool): Whether to extract compressed files.

    Returns:
        A list of paths to the downloaded files.

    Raises:
        ValueError: If :attr:`filenames` is given but its length does not
            match that of :attr:`urls`.
    """
    utils_io.maybe_create_dir(path)

    if not isinstance(urls, (list, tuple)):
        urls = [urls]
    if filenames is not None:
        if not isinstance(filenames, (list, tuple)):
            filenames = [filenames]
        if len(urls) != len(filenames):
            raise ValueError(
                '`filenames` must have the same number of elements as `urls`.')

    result = []
    for i, url in enumerate(urls):
        if filenames is not None:
            filename = filenames[i]
        elif 'drive.google.com' in url:
            # Google Drive urls do not end with the file name; use the
            # embedded file id as the local name instead.
            filename = _extract_google_drive_file_id(url)
        else:
            filename = url.split('/')[-1]
            # If downloading from GitHub, remove suffix ?raw=True
            # from local filename
            if filename.endswith("?raw=true"):
                filename = filename[:-9]

        filepath = os.path.join(path, filename)
        result.append(filepath)

        # Skip the download entirely when the file already exists locally.
        if not tf.gfile.Exists(filepath):
            if 'drive.google.com' in url:
                filepath = _download_from_google_drive(url, filename, path)
            else:
                filepath = _download(url, filename, path)

            if extract:
                tf.logging.info('Extract %s', filepath)
                if tarfile.is_tarfile(filepath):
                    # Use a context manager so the archive handle is closed
                    # (the previous code leaked the open tarfile object).
                    with tarfile.open(filepath, 'r') as archive:
                        archive.extractall(path)
                elif zipfile.is_zipfile(filepath):
                    with zipfile.ZipFile(filepath) as zfile:
                        zfile.extractall(path)
                else:
                    tf.logging.info("Unknown compression type. Only .tar.gz, "
                                    ".tar.bz2, .tar, and .zip are supported")

    return result
def _download(url, filename, path):
    """Downloads `url` into directory `path` under local name `filename`.

    Prints a progress indicator to stdout during the transfer and a summary
    line on completion.

    Returns:
        The path of the downloaded file.
    """
    def _progress(count, block_size, total_size):
        # urlretrieve reports total_size <= 0 when the server omits
        # Content-Length; skip the percentage in that case to avoid a
        # ZeroDivisionError / meaningless output.
        if total_size > 0:
            percent = float(count * block_size) / float(total_size) * 100.
            # pylint: disable=cell-var-from-loop
            sys.stdout.write('\r>> Downloading %s %.1f%%' %
                             (filename, percent))
            sys.stdout.flush()

    filepath = os.path.join(path, filename)
    filepath, _ = urllib.request.urlretrieve(url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded {} {} bytes.'.format(
        filename, statinfo.st_size))
    return filepath
def _extract_google_drive_file_id(url):
# id is between `/d/` and '/'
url_suffix = url[url.find('/d/') + 3:]
file_id = url_suffix[:url_suffix.find('/')]
return file_id
def _download_from_google_drive(url, filename, path):
    """Downloads a file hosted on Google Drive.

    Adapted from `https://github.com/saurabhshri/gdrive-downloader`
    """
    def _confirm_token(resp):
        # For large files Google Drive replies with a virus-scan warning
        # page and stores a confirmation token in a cookie.
        for name, value in resp.cookies.items():
            if name.startswith('download_warning'):
                return value
        return None

    file_id = _extract_google_drive_file_id(url)

    gurl = "https://docs.google.com/uc?export=download"
    sess = requests.Session()
    response = sess.get(gurl, params={'id': file_id}, stream=True)

    token = _confirm_token(response)
    if token:
        # Re-issue the request with the confirmation token attached.
        response = sess.get(
            gurl, params={'id': file_id, 'confirm': token}, stream=True)

    filepath = os.path.join(path, filename)
    CHUNK_SIZE = 32768
    with tf.gfile.GFile(filepath, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)

    print('Successfully downloaded {}.'.format(filename))

    return filepath
def read_words(filename, newline_token=None):
    """Reads word from a file.

    Args:
        filename (str): Path to the file.
        newline_token (str, optional): The token to replace the original
            newline token "\\n". For example,
            `newline_token=tx.data.SpecialTokens.EOS`.
            If `None`, no replacement is performed.

    Returns:
        A list of words.
    """
    with tf.gfile.GFile(filename, "r") as f:
        content = f.read()
    if not Py3:
        # Python 2: GFile yields bytes; decode to unicode before splitting.
        content = content.decode("utf-8")
    if newline_token is not None:
        content = content.replace("\n", newline_token)
    return content.split()
def make_vocab(filenames, max_vocab_size=-1, newline_token=None,
               return_type="list", return_count=False):
    """Builds vocab of the files.

    Args:
        filenames (str): A (list of) files.
        max_vocab_size (int): Maximum size of the vocabulary. Low frequency
            words exceeding the limit will be discarded.
            Set to `-1` (default) if no truncation is wanted.
        newline_token (str, optional): The token to replace the original
            newline token "\\n". For example,
            `newline_token=tx.data.SpecialTokens.EOS`.
            If `None`, no replacement is performed.
        return_type (str): Either "list" or "dict". If "list" (default), this
            function returns a list of words sorted by frequency. If "dict",
            this function returns a dict mapping words to their index sorted
            by frequency.
        return_count (bool): Whether to return word counts. If `True` and
            :attr:`return_type` is "dict", then a count dict is returned,
            which is a mapping from words to their frequency.

    Returns:
        - If :attr:`return_count` is False, returns a list or dict containing \
        the vocabulary words.

        - If :attr:`return_count` if True, returns a pair of list or dict \
        `(a, b)`, where `a` is a list or dict containing the vocabulary \
        words, `b` is a list of dict containing the word counts.

    Raises:
        ValueError: If :attr:`return_type` is neither "list" nor "dict".
    """
    if not isinstance(filenames, (list, tuple)):
        filenames = [filenames]

    words = []
    for fn in filenames:
        words += read_words(fn, newline_token=newline_token)

    counter = collections.Counter(words)
    # Sort by descending frequency, breaking ties alphabetically so the
    # resulting vocabulary order is deterministic across runs.
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    if count_pairs:
        words, counts = list(zip(*count_pairs))
    else:
        # Empty corpus: `zip(*[])` yields nothing, so the tuple unpacking
        # above would raise ValueError. Return empty vocab instead.
        words, counts = (), ()

    if max_vocab_size >= 0:
        words = words[:max_vocab_size]
        counts = counts[:max_vocab_size]

    if return_type == "list":
        if not return_count:
            return words
        else:
            return words, counts
    elif return_type == "dict":
        word_to_id = dict(zip(words, range(len(words))))
        if not return_count:
            return word_to_id
        else:
            word_to_count = dict(zip(words, counts))
            return word_to_id, word_to_count
    else:
        raise ValueError("Unknown return_type: {}".format(return_type))
def count_file_lines(filenames):
    """Counts the number of lines in the file(s).

    Args:
        filenames: A (list of) paths to the file(s).

    Returns:
        The total line count over all files.
    """
    def _count_lines(fn):
        # An empty file never enters the loop, leaving the index at -1,
        # so the function correctly returns 0.
        last_idx = -1
        with open(fn, "rb") as f:
            for last_idx, _ in enumerate(f):
                pass
        return last_idx + 1

    if not isinstance(filenames, (list, tuple)):
        filenames = [filenames]
    return np.sum([_count_lines(fn) for fn in filenames])