from . import shallownlp, textutils
from .data import texts_from_array, texts_from_csv, texts_from_df, texts_from_folder
from .eda import get_topic_model
from .models import (
    print_text_classifiers,
    print_text_regression_models,
    text_classifier,
    text_regression_model,
)
from .ner.data import (
    entities_from_array,
    entities_from_conll2003,
    entities_from_df,
    entities_from_gmb,
    entities_from_txt,
)
from .ner.models import print_sequence_taggers, sequence_tagger
from .preprocessor import Transformer, TransformerEmbedding
from .qa import AnswerExtractor, SimpleQA
from .summarization import TransformerSummarizer
from .textextractor import TextExtractor
from .textutils import extract_filenames, filter_by_id, load_text_files
from .translation import EnglishTranslator, Translator
from .zsl import ZeroShotClassifier

__all__ = [
    "text_classifier",
    "text_regression_model",
    "print_text_classifiers",
    "print_text_regression_models",
    "texts_from_folder",
    "texts_from_csv",
    "texts_from_df",
    "texts_from_array",
    "entities_from_gmb",
    "entities_from_conll2003",
    "entities_from_txt",
    "entities_from_array",
    "entities_from_df",
    "sequence_tagger",
    "print_sequence_taggers",
    "get_topic_model",
    "Transformer",
    "TransformerEmbedding",
    "shallownlp",
    "TransformerSummarizer",
    "ZeroShotClassifier",
    "EnglishTranslator",
    "Translator",
    "SimpleQA",
    "AnswerExtractor",
    "TextExtractor",
    "extract_filenames",
    "load_text_files",
]


def load_topic_model(fname):
    """
    Load saved TopicModel object
    Args:
      fname(str): base filename for all saved files
    """
    with open(fname + ".tm_vect", "rb") as f:
        vectorizer = pickle.load(f)
    with open(fname + ".tm_model", "rb") as f:
        model = pickle.load(f)
    with open(fname + ".tm_params", "rb") as f:
        params = pickle.load(f)
    tm = get_topic_model(
        n_topics=params["n_topics"],
        n_features=params["n_features"],
        verbose=params["verbose"],
    )
    tm.model = model
    tm.vectorizer = vectorizer
    return tm


seqlen_stats = Transformer.seqlen_stats
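
A brief, illustrative sketch of how the module-level topic-model helpers above fit together. The 'data/corpus' folder and 'mytopicmodel' filename prefix are placeholders, and the sketch assumes TopicModel.save() writes the .tm_vect/.tm_model/.tm_params files that load_topic_model reads.

from ktrain import text

docs = text.load_text_files('data/corpus')      # placeholder folder of plain-text documents

tm = text.get_topic_model(docs, n_topics=10, n_features=5000)   # fit a topic model
tm.save('mytopicmodel')                          # writes mytopicmodel.tm_vect / .tm_model / .tm_params
tm = text.load_topic_model('mytopicmodel')       # restore the TopicModel object
tm.build(docs, threshold=0.25)                   # re-attach documents before querying topics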

Sub-modules

ktrain.text.data
ktrain.text.dataset
ktrain.text.eda
ktrain.text.generative_ai
ktrain.text.kw
ktrain.text.learner
ktrain.text.models
ktrain.text.ner
ktrain.text.predictor
ktrain.text.preprocessor
ktrain.text.qa
ktrain.text.sentiment
ktrain.text.shallownlp
ktrain.text.speech
ktrain.text.summarization
ktrain.text.textextractor
ktrain.text.textutils
ktrain.text.translation
ktrain.text.zsl
def entities_from_array ( x_train, y_train, x_test=None, y_test=None, use_char=False, val_pct=0.1, verbose=1)

Load entities from arrays

x_train(list): list of list of entity tokens for training
               Example: x_train = [['Hello', 'world'], ['Hello', 'Cher'], ['I', 'love', 'Chicago']]
y_train(list): list of list of tokens representing entity labels
               Example: y_train = [['O', 'O'], ['O', 'B-PER'], ['O', 'O', 'B-LOC']]
x_test(list): list of list of entity tokens for validation
              Example: x_test = [['Hello', 'world'], ['Hello', 'Cher'], ['I', 'love', 'Chicago']]
y_test(list): list of list of tokens representing entity labels
              Example: y_test = [['O', 'O'], ['O', 'B-PER'], ['O', 'O', 'B-LOC']]
use_char(bool): If True, data will be preprocessed to use character embeddings in addition to word embeddings
val_pct(float): percentage of training to use for validation if no validation data is supplied
verbose (boolean): verbosity

Expand source code
def entities_from_array(
    x_train, y_train, x_test=None, y_test=None, use_char=False, val_pct=0.1, verbose=1
    Load entities from arrays
    Args:
      x_train(list): list of list of entity tokens for training
                     Example: x_train = [['Hello', 'world'], ['Hello', 'Cher'], ['I', 'love', 'Chicago']]
      y_train(list): list of list of tokens representing entity labels
                     Example:  y_train = [['O', 'O'], ['O', 'B-PER'], ['O', 'O', 'B-LOC']]
      x_test(list): list of list of entity tokens for validation
                     Example: x_test = [['Hello', 'world'], ['Hello', 'Cher'], ['I', 'love', 'Chicago']]
      y_test(list): list of list of tokens representing entity labels
                     Example:  y_test = [['O', 'O'], ['O', 'B-PER'], ['O', 'O', 'B-LOC']]
     use_char(bool):    If True, data will be preprocessed to use character embeddings  in addition to word embeddings
     val_pct(float):  percentage of training to use for validation if no validation data is supplied
     verbose (boolean): verbosity
    # TODO: converting to df to use entities_from_df - needs to be refactored
    train_df = pp.array_to_df(x_train, y_train)
    val_df = None
    if x_test is not None and y_test is not None:
        val_df = pp.array_to_df(x_test, y_test)
    if verbose:
        print("training data sample:")
        print(train_df.head())
        if val_df is not None:
            print("validation data sample:")
            print(val_df.head())
    return entities_from_df(
        train_df, val_df=val_df, val_pct=val_pct, use_char=use_char, verbose=verbose
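
A minimal usage sketch for entities_from_array using the toy arrays from the docstring above; the returned tuple feeds directly into sequence_tagger and ktrain.get_learner.

from ktrain import text

x_train = [['Hello', 'world'], ['Hello', 'Cher'], ['I', 'love', 'Chicago']]
y_train = [['O', 'O'], ['O', 'B-PER'], ['O', 'O', 'B-LOC']]

# no validation arrays supplied, so val_pct of the training data is held out
trn, val, preproc = text.entities_from_array(x_train, y_train, val_pct=0.34)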
def entities_from_conll2003(train_filepath, val_filepath=None, use_char=False, encoding=None, val_pct=0.1, verbose=1)

Loads sequence-labeled data from a file in CoNLL2003 format.

Expand source code
def entities_from_conll2003(
    train_filepath,
    val_filepath=None,
    use_char=False,
    encoding=None,
    val_pct=0.1,
    verbose=1,
    Loads sequence-labeled data from a file in CoNLL2003 format.
    return entities_from_txt(
        train_filepath=train_filepath,
        val_filepath=val_filepath,
        use_char=use_char,
        data_format="conll2003",
        encoding=encoding,
        val_pct=val_pct,
        verbose=verbose,
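
An illustrative call to entities_from_conll2003; 'data/train.txt' and 'data/valid.txt' are hypothetical CoNLL2003-formatted files (one token and tag per line, with a blank line between sentences).

from ktrain import text

trn, val, preproc = text.entities_from_conll2003('data/train.txt',
                                                 val_filepath='data/valid.txt')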
def entities_from_df(train_df, val_df=None, word_column='Word', tag_column='Tag', sentence_column='SentenceID', use_char=False, val_pct=0.1, verbose=1)

Load entities from pandas DataFrame

train_df(pd.DataFrame): training data
val_df(pd.DataFrame): validation data
word_column(str): name of column containing the text
tag_column(str): name of column containing the label
sentence_column(str): name of column containing Sentence IDs
use_char(bool): If True, data will be preprocessed to use character embeddings in addition to word embeddings
verbose : boolean
verbosity
Expand source code
def entities_from_df(
    train_df,
    val_df=None,
    word_column=WORD_COL,
    tag_column=TAG_COL,
    sentence_column=SENT_COL,
    use_char=False,
    val_pct=0.1,
    verbose=1,
    Load entities from pandas DataFrame
    Args:
      train_df(pd.DataFrame): training data
      val_df(pd.DataFrame): validation data
      word_column(str): name of column containing the text
      tag_column(str): name of column containing the label
      sentence_column(str): name of column containing Sentence IDs
      use_char(bool):    If True, data will be preprocessed to use character embeddings  in addition to word embeddings
      verbose (boolean): verbosity
    # process dataframe and instantiate NERPreprocessor
    x, y = pp.process_df(
        train_df,
        word_column=word_column,
        tag_column=tag_column,
        sentence_column=sentence_column,
        verbose=verbose,
    # get validation set
    if val_df is None:
        x_train, x_valid, y_train, y_valid = train_test_split(x, y, test_size=val_pct)
    else:
        x_train, y_train = x, y
        (x_valid, y_valid) = pp.process_df(
            val_df,
            word_column=word_column,
            tag_column=tag_column,
            sentence_column=sentence_column,
            verbose=0,
    # preprocess and convert to generator
    from .anago.preprocessing import IndexTransformer
    p = IndexTransformer(use_char=use_char)
    preproc = NERPreprocessor(p)
    preproc.fit(x_train, y_train)
    from .dataset import NERSequence
    trn = NERSequence(x_train, y_train, batch_size=U.DEFAULT_BS, p=p)
    val = NERSequence(x_valid, y_valid, batch_size=U.DEFAULT_BS, p=p)
    return (trn, val, preproc)
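
A short sketch of entities_from_df with a hand-built long-format DataFrame (one token per row, grouped by sentence); the column names follow the defaults documented above.

import pandas as pd
from ktrain import text

df = pd.DataFrame({
    'SentenceID': [1, 1, 2, 2, 2, 3, 3],
    'Word': ['Hello', 'Cher', 'I', 'love', 'Chicago', 'Hello', 'world'],
    'Tag': ['O', 'B-PER', 'O', 'O', 'B-LOC', 'O', 'O'],
})
trn, val, preproc = text.entities_from_df(df, val_pct=0.34)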
def entities_from_gmb(train_filepath, val_filepath=None, use_char=False, word_column='Word', tag_column='Tag', sentence_column='SentenceID', encoding=None, val_pct=0.1, verbose=1)

Loads sequence-labeled data from text file in the Groningen Meaning Bank (GMB) format.

Expand source code
def entities_from_gmb(
    train_filepath,
    val_filepath=None,
    use_char=False,
    word_column=WORD_COL,
    tag_column=TAG_COL,
    sentence_column=SENT_COL,
    encoding=None,
    val_pct=0.1,
    verbose=1,
    Loads sequence-labeled data from text file in the  Groningen
    Meaning Bank  (GMB) format.
    return entities_from_txt(
        train_filepath=train_filepath,
        val_filepath=val_filepath,
        use_char=use_char,
        word_column=word_column,
        tag_column=tag_column,
        sentence_column=sentence_column,
        data_format="gmb",
        encoding=encoding,
        val_pct=val_pct,
        verbose=verbose,
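
An illustrative call to entities_from_gmb; 'data/ner_dataset.csv' is a placeholder for a GMB-style file whose columns match the word_column/tag_column/sentence_column arguments.

from ktrain import text

trn, val, preproc = text.entities_from_gmb('data/ner_dataset.csv',
                                           word_column='Word',
                                           tag_column='Tag',
                                           sentence_column='SentenceID')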
def entities_from_txt(train_filepath, val_filepath=None, use_char=False, word_column='Word', tag_column='Tag', sentence_column='SentenceID', data_format='conll2003', encoding=None, val_pct=0.1, verbose=1)

Loads sequence-labeled data from a comma- or tab-delimited text file. The file is in either the CoNLL2003 format or the Groningen Meaning Bank (GMB) format, as specified by the data_format parameter.

In both formats, each word appears on a separate line along with its associated tag (or label). The last item on each line should be the tag or label assigned to that word.

In the CoNLL2003 format, there is an empty line after each sentence. In the GMB format, sentences are delineated with a third column denoting the Sentence ID.

More information on CoNLL2003 format: https://www.aclweb.org/anthology/W03-0419

CoNLL example (each column is typically separated by space or tab), with no column headings:

   Paul     B-PER
   Newman   I-PER
   is       O
   a        O
   great    O
   actor    O
   !        O

More information on GMB format: Refer to ner_dataset.csv on Kaggle here: https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus/version/2

GMB example (each column separated by comma or tab), with column headings:

   SentenceID   Word     Tag
   1            Paul     B-PER
   1            Newman   I-PER
   1            is       O
   1            a        O
   1            great    O
   1            actor    O
   1            !        O

train_filepath(str): file path to training CSV
val_filepath(str): file path to validation dataset
use_char(bool): If True, data will be preprocessed to use character embeddings in addition to word embeddings
word_column(str): name of column containing the text
tag_column(str): name of column containing the label
sentence_column(str): name of column containing Sentence IDs
data_format(str): one of 'conll2003' or 'gmb'
                  word_column, tag_column, and sentence_column are ignored if 'conll2003'
encoding(str): the encoding to use. If None, encoding is discovered automatically
val_pct(float): Proportion of training to use for validation.
verbose (boolean): verbosity
Expand source code
def entities_from_txt(
    train_filepath,
    val_filepath=None,
    use_char=False,
    word_column=WORD_COL,
    tag_column=TAG_COL,
    sentence_column=SENT_COL,
    data_format="conll2003",
    encoding=None,
    val_pct=0.1,
    verbose=1,
    Loads sequence-labeled data from a comma or tab-delimited text file.
    Format of file is either the CoNLL2003 format or Groningen Meaning
    Bank (GMB) format - specified with data_format parameter.
    In both formats, each word appears on a separate line along with
    its associated tag (or label).
    The last item on each line should be the tag or label assigned to that word.
    In the CoNLL2003 format, there is an empty line after
    each sentence.  In the GMB format, sentences are delineated
    with a third column denoting the Sentence ID.
    More information on CoNLL2003 format:
       https://www.aclweb.org/anthology/W03-0419
    CoNLL Example (each column is typically separated by space or tab)
    and  no column headings:
       Paul     B-PER
       Newman   I-PER
       is       O
       a        O
       great    O
       actor    O
       !        O
    More information on GMB format:
    Refer to ner_dataset.csv on Kaggle here:
       https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus/version/2
    GMB example (each column separated by comma or tab)
    with column headings:
      SentenceID   Word     Tag
      1            Paul     B-PER
      1            Newman   I-PER
      1            is       O
      1            a        O
      1            great    O
      1            actor    O
      1            !        O
    Args:
        train_filepath(str): file path to training CSV
        val_filepath (str): file path to validation dataset
        use_char(bool):    If True, data will be preprocessed to use character embeddings in addition to word embeddings
        word_column(str): name of column containing the text
        tag_column(str): name of column containing the label
        sentence_column(str): name of column containing Sentence IDs
        data_format(str): one of 'conll2003' or 'gmb'
                          word_column, tag_column, and sentence_column
                          ignored if 'conll2003'
        encoding(str): the encoding to use.  If None, encoding is discovered automatically
        val_pct(float): Proportion of training to use for validation.
        verbose (boolean): verbosity
    # set dataframe converter
    if data_format == "gmb":
        data_to_df = pp.gmb_to_df
    else:
        data_to_df = pp.conll2003_to_df
        word_column, tag_column, sentence_column = WORD_COL, TAG_COL, SENT_COL
    # detect encoding
    if encoding is None:
        with open(train_filepath, "rb") as f:
            encoding = TU.detect_encoding(f.read())
            U.vprint(
                "detected encoding: %s (if wrong, set manually)" % (encoding),
                verbose=verbose,
    # create dataframe
    train_df = data_to_df(train_filepath, encoding=encoding)
    val_df = (
        None if val_filepath is None else data_to_df(val_filepath, encoding=encoding)
    return entities_from_df(
        train_df,
        val_df=val_df,
        word_column=word_column,
        tag_column=tag_column,
        sentence_column=sentence_column,
        use_char=use_char,
        val_pct=val_pct,
        verbose=verbose,
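
Since entities_from_txt is the general entry point wrapped by the two loaders above, a sketch only needs the data_format switch; the file paths below are placeholders.

from ktrain import text

trn, val, preproc = text.entities_from_txt('data/train.txt',
                                           val_filepath='data/valid.txt',
                                           data_format='conll2003')  # or 'gmb'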
def extract_filenames(corpus_path, follow_links=False)
def extract_filenames(corpus_path, follow_links=False):
    if os.listdir(corpus_path) == []:
        raise ValueError("%s: path is empty" % corpus_path)
    walk = os.walk
    for root, dirs, filenames in walk(corpus_path, followlinks=follow_links):
        for filename in filenames:
            try:
                yield os.path.join(root, filename)
            except:
                continue
def load_text_files(corpus_path, truncate_len=None, clean=True, return_fnames=False)
load text files
Expand source code
def load_text_files(corpus_path, truncate_len=None, clean=True, return_fnames=False):
    load text files
    texts = []
    filenames = []
    mb = master_bar(range(1))
    for i in mb:
        for filename in progress_bar(list(extract_filenames(corpus_path)), parent=mb):
            with open(filename, "r") as f:
                text = f.read()
            if clean:
                text = strip_control_characters(text)
                text = to_ascii(text)
            if truncate_len is not None:
                text = " ".join(text.split()[:truncate_len])
            texts.append(text)
            filenames.append(filename)
        mb.write("done.")
    if return_fnames:
        return (texts, filenames)
    else:
        return texts
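
A quick sketch of load_text_files; 'data/corpus' is a hypothetical folder of plain-text files, and truncate_len keeps only the first 500 whitespace-delimited tokens of each document.

from ktrain import text

texts, fnames = text.load_text_files('data/corpus',
                                     truncate_len=500,
                                     return_fnames=True)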
def print_sequence_taggers()

def sequence_tagger(name, preproc, wv_path_or_url=None, transformer_model='bert-base-multilingual-cased', transformer_layers_to_use=[-2], bert_model=None, word_embedding_dim=100, char_embedding_dim=25, word_lstm_size=100, char_lstm_size=25, fc_dim=100, dropout=0.5, verbose=1)

Build and return a sequence tagger (i.e., named entity recognizer).

name (string): one of:
               - 'bilstm-crf' for Bidirectional LSTM-CRF model
               - 'bilstm' for Bidirectional LSTM (no CRF layer)

preproc(NERPreprocessor): an instance of NERPreprocessor

wv_path_or_url(str): either a URL or file path to a fasttext word vector file (.vec or .vec.zip or .vec.gz)
                     Example valid values for wv_path_or_url:

                   Randomly-initialized word embeddings:
                     set wv_path_or_url=None
                   English pretrained word vectors:
                     <https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip>
                   Chinese pretrained word vectors:
                     <https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz>
                   Russian pretrained word vectors:
                     <https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.vec.gz>
                   Dutch pretrained word vectors:
                     <https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz>
                 See these two Web pages for a full list of URLs to word vector files for
                 different languages:
                    1.  <https://fasttext.cc/docs/en/english-vectors.html> (for English)
                     2.  <https://fasttext.cc/docs/en/crawl-vectors.html> (for non-English languages)
                Default:None (randomly-initialized word embeddings are used)

transformer_model(str): the name of the transformer model. default: 'bert-base-multilingual-cased'
                        This parameter is only used if 'bilstm-transformer' is selected for the name parameter.
                        The value of this parameter is the name of a transformer model from
                        https://huggingface.co/transformers/pretrained_models.html or a
                        community-uploaded model from https://huggingface.co/models
                        Example values:
                          bert-base-multilingual-cased: Multilingual BERT (157 languages) - this is the default
                          bert-base-cased: English BERT
                          bert-base-chinese: Chinese BERT
                          distilbert-base-german-cased: German DistilBert
                          albert-base-v2: English ALBERT model
                          monologg/biobert_v1.1_pubmed: community uploaded BioBERT (pretrained on PubMed)

transformer_layers_to_use(list): indices of hidden layers to use. default:[-2] # second-to-last layer
To use the concatenation of last 4 layers: use [-1, -2, -3, -4]
bert_model(str): alias for transformer_model
word_embedding_dim : int
word embedding dimensions.
char_embedding_dim : int
character embedding dimensions.
word_lstm_size : int
word-level (tagger) LSTM output dimensions.
char_lstm_size : int
character LSTM feature extractor output dimensions.
fc_dim : int
output fully-connected layer size.
dropout : float
dropout rate.
verbose : boolean
verbosity of output

Return

model (Model): A Keras Model instance

Expand source code
def sequence_tagger(
    name,
    preproc,
    wv_path_or_url=None,
    transformer_model="bert-base-multilingual-cased",
    transformer_layers_to_use=U.DEFAULT_TRANSFORMER_LAYERS,
    bert_model=None,
    word_embedding_dim=100,
    char_embedding_dim=25,
    word_lstm_size=100,
    char_lstm_size=25,
    fc_dim=100,
    dropout=0.5,
    verbose=1,
    Build and return a sequence tagger (i.e., named entity recognizer).
    Args:
        name (string): one of:
                      - 'bilstm-crf' for Bidirectional LSTM-CRF model
                      - 'bilstm' for Bidirectional LSTM (no CRF layer)
        preproc(NERPreprocessor):  an instance of NERPreprocessor
        wv_path_or_url(str): either a URL or file path to a fasttext word vector file (.vec or .vec.zip or .vec.gz)
                             Example valid values for wv_path_or_url:
                               Randomly-initialized word embeddings:
                                 set wv_path_or_url=None
                               English pretrained word vectors:
                                 https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip
                               Chinese pretrained word vectors:
                                 https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz
                               Russian pretrained word vectors:
                                 https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.vec.gz
                               Dutch pretrained word vectors:
                                 https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz
                             See these two Web pages for a full list of URLs to word vector files for
                             different languages:
                                1.  https://fasttext.cc/docs/en/english-vectors.html (for English)
                                2.  https://fasttext.cc/docs/en/crawl-vectors.html (for non-English languages)
                            Default:None (randomly-initialized word embeddings are used)
        transformer_model(str):  the name of the transformer model.  default: 'bert-base-multilingual-cased'
                                      This parameter is only used if 'bilstm-transformer' is selected for the name parameter.
                                       The value of this parameter is a name of transformer model from here:
                                            https://huggingface.co/transformers/pretrained_models.html
                                       or a community-uploaded BERT model from here:
                                           https://huggingface.co/models
                               Example values:
                                 bert-base-multilingual-cased:  Multilingual BERT (157 languages) - this is the default
                                 bert-base-cased:  English BERT
                                 bert-base-chinese: Chinese BERT
                                 distilbert-base-german-cased: German DistilBert
                                 albert-base-v2: English ALBERT model
                                 monologg/biobert_v1.1_pubmed: community uploaded BioBERT (pretrained on PubMed)
        transformer_layers_to_use(list): indices of hidden layers to use.  default:[-2] # second-to-last layer
                                         To use the concatenation of last 4 layers: use [-1, -2, -3, -4]
        bert_model(str): alias for transformer_model
        word_embedding_dim (int): word embedding dimensions.
        char_embedding_dim (int): character embedding dimensions.
        word_lstm_size (int): word-level (tagger) LSTM output dimensions.
        char_lstm_size (int): character LSTM feature extractor output dimensions.
        fc_dim (int): output fully-connected layer size.
        dropout (float): dropout rate.
        verbose (boolean): verbosity of output
    Return:
        model (Model): A Keras Model instance
    # backwards compatibility
    name = BILSTM_TRANSFORMER if name == "bilstm-bert" else name
    if bert_model is not None:
        transformer_model = bert_model
        warnings.warn(
            "The bert_model argument is deprecated - please use transformer_model instead.",
            DeprecationWarning,
            stacklevel=2,
    if name not in SEQUENCE_TAGGERS:
        raise ValueError(
            f"Invalid model name {name}. {'Did you mean bilstm-transformer?' if name == 'bilstm-bert' else ''}"
    # check BERT
    if name in TRANSFORMER_MODELS and not transformer_model:
        raise ValueError(
            f"transformer_model is required for {BILSTM_TRANSFORMER} models"
    if name in TRANSFORMER_MODELS and DISABLE_V2_BEHAVIOR:
        raise ValueError(
            "BERT and other transformer models cannot be used with DISABLE_v2_BEHAVIOR"
    # check CRF
    if not DISABLE_V2_BEHAVIOR and name in V1_ONLY_MODELS:
        warnings.warn(
            "Falling back to BiLSTM (no CRF) because DISABLE_V2_BEHAVIOR=False"
        msg = (
            "\nIMPORTANT NOTE: ktrain uses the CRF module from keras_contrib, which is not yet\n"
            + "fully compatible with TensorFlow 2. You can still use the BiLSTM-CRF model\n"
            + "in ktrain for sequence tagging with TensorFlow 2, but you must add the\n"
            + "following to the top of your script or notebook BEFORE you import ktrain:\n\n"
            + "import os\n"
            + "os.environ['DISABLE_V2_BEHAVIOR'] = '1'\n\n"
            + "For this run, a vanilla BiLSTM model (with no CRF layer) will be used.\n"
        print(msg)
        name = BILSTM if name == BILSTM_CRF else BILSTM_ELMO
    # check for use_char=True
    if not DISABLE_V2_BEHAVIOR and preproc.p._use_char:
        # turn off masking due to open TF2 issue ##33148: https://github.com/tensorflow/tensorflow/issues/33148
        warnings.warn(
            "Setting use_char=False:  character embeddings cannot be used in TF2 due to open TensorFlow 2 bug (#33148).\n"
            + 'Add os.environ["DISABLE_V2_BEHAVIOR"] = "1" to the top of script if you really want to use it.'
        preproc.p._use_char = False
    if verbose:
        emb_names = []
        if wv_path_or_url is not None:
            emb_names.append(
                "word embeddings initialized with fasttext word vectors (%s)"
                % (os.path.basename(wv_path_or_url))
        else:
            emb_names.append("word embeddings initialized randomly")
        if name in TRANSFORMER_MODELS:
            emb_names.append("transformer embeddings with " + transformer_model)
        if name in ELMO_MODELS:
            emb_names.append("Elmo embeddings for English")
        if preproc.p._use_char:
            emb_names.append("character embeddings")
        if len(emb_names) > 1:
            print("Embedding schemes employed (combined with concatenation):")
        else:
            print("embedding schemes employed:")
        for emb_name in emb_names:
            print("\t%s" % (emb_name))
        print()
    # setup embedding
    if wv_path_or_url is not None:
        wv_model, word_embedding_dim = preproc.get_wv_model(
            wv_path_or_url, verbose=verbose
    else:
        wv_model = None
    if name == BILSTM_CRF:
        use_crf = False if not DISABLE_V2_BEHAVIOR else True  # fallback to bilstm
    elif name == BILSTM_CRF_ELMO:
        use_crf = False if not DISABLE_V2_BEHAVIOR else True  # fallback to bilstm
        preproc.p.activate_elmo()
    elif name == BILSTM:
        use_crf = False
    elif name == BILSTM_ELMO:
        use_crf = False
        preproc.p.activate_elmo()
    elif name == BILSTM_TRANSFORMER:
        use_crf = False
        preproc.p.activate_transformer(
            transformer_model, layers=transformer_layers_to_use, force=True
    else:
        raise ValueError("Unsupported model name")
    from .anago.models import BiLSTMCRF
    model = BiLSTMCRF(
        char_embedding_dim=char_embedding_dim,
        word_embedding_dim=word_embedding_dim,
        char_lstm_size=char_lstm_size,
        word_lstm_size=word_lstm_size,
        fc_dim=fc_dim,
        char_vocab_size=preproc.p.char_vocab_size,
        word_vocab_size=preproc.p.word_vocab_size,
        num_labels=preproc.p.label_size,
        dropout=dropout,
        use_crf=use_crf,
        use_char=preproc.p._use_char,
        embeddings=wv_model,
        use_elmo=preproc.p.elmo_is_activated(),
        use_transformer_with_dim=preproc.p.get_transformer_dim(),
    model, loss = model.build()
    model.compile(loss=loss, optimizer=U.DEFAULT_OPT)
    return model
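
A hedged end-to-end sketch for sequence_tagger, assuming trn, val, and preproc were returned by one of the entities_from_* loaders above; the batch size and learning rate are illustrative.

import ktrain
from ktrain import text

model = text.sequence_tagger('bilstm-transformer', preproc,
                             transformer_model='bert-base-multilingual-cased')
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)
learner.fit_onecycle(0.01, 1)   # one epoch at an illustrative learning rate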
def text_classifier(name, train_data, preproc=None, multilabel=None, metrics=None, verbose=1)
Build and return a text classification model.
Args:
    name (string): one of:
                  - 'fasttext' for FastText model
                  - 'nbsvm' for NBSVM model
                  - 'logreg' for logistic regression using embedding layers
                  - 'bigru' for Bidirectional GRU with pretrained word vectors
                  - 'bert' for BERT Text Classification
                  - 'distilbert' for Hugging Face DistilBert model
    train_data (tuple): a tuple of numpy.ndarrays: (x_train, y_train) or ktrain.Dataset instance
                        returned from one of the texts_from_* functions
    preproc: a ktrain.text.TextPreprocessor instance.
             As of v0.8.0, this is required.
    multilabel (bool):  If True, multilabel model will be returned.
                        If false, binary/multiclass model will be returned.
                        If None, multilabel will be inferred from data.
    metrics(list): List of metrics to use.  If None: 'accuracy' is used for binary/multiclass classification,
                   'binary_accuracy' is used for multilabel classification, and 'mae' is used for regression.
    verbose (boolean): verbosity of output
Return:
    model (Model): A Keras Model instance
Expand source code
def text_classifier(
    name, train_data, preproc=None, multilabel=None, metrics=None, verbose=1
    Build and return a text classification model.
    Args:
        name (string): one of:
                      - 'fasttext' for FastText model
                      - 'nbsvm' for NBSVM model
                      - 'logreg' for logistic regression using embedding layers
                      - 'bigru' for Bidirectional GRU with pretrained word vectors
                      - 'bert' for BERT Text Classification
                      - 'distilbert' for Hugging Face DistilBert model
        train_data (tuple): a tuple of numpy.ndarrays: (x_train, y_train) or ktrain.Dataset instance
                            returned from one of the texts_from_* functions
        preproc: a ktrain.text.TextPreprocessor instance.
                 As of v0.8.0, this is required.
        multilabel (bool):  If True, multilabel model will be returned.
                            If false, binary/multiclass model will be returned.
                            If None, multilabel will be inferred from data.
        metrics(list): List of metrics to use.  If None: 'accuracy' is used for binary/multiclass classification,
                       'binary_accuracy' is used for multilabel classification, and 'mae' is used for regression.
        verbose (boolean): verbosity of output
    Return:
        model (Model): A Keras Model instance
    if name not in TEXT_CLASSIFIERS:
        raise ValueError("invalid name for text classification: %s" % (name))
    if preproc is not None and not preproc.get_classes():
        raise ValueError(
            "preproc.get_classes() is empty, but required for text classification"
    return _text_model(
        name,
        train_data,
        preproc=preproc,
        multilabel=multilabel,
        classification=True,
        metrics=metrics,
        verbose=verbose,
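
A minimal training sketch for text_classifier, assuming (trn, val, preproc) came from one of the texts_from_* loaders with preprocess_mode='distilbert'; the hyperparameters are illustrative.

import ktrain
from ktrain import text

model = text.text_classifier('distilbert', train_data=trn, preproc=preproc)
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(5e-5, 1)   # one epoch at a typical transformer learning rate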
def text_regression_model(name, train_data, preproc=None, metrics=['mae'], verbose=1)
Build and return a text regression model.
Args:
    name (string): one of:
                  - 'fasttext' for FastText model
                  - 'nbsvm' for NBSVM model
                  - 'linreg' for linear regression using embedding layers
                  - 'bigru' for Bidirectional GRU with pretrained word vectors
                  - 'bert' for BERT Text Classification
                  - 'distilbert' for Hugging Face DistilBert model
    train_data (tuple): a tuple of numpy.ndarrays: (x_train, y_train)
    preproc: a ktrain.text.TextPreprocessor instance.
             As of v0.8.0, this is required.
    metrics(list): metrics to use
    verbose (boolean): verbosity of output
Return:
    model (Model): A Keras Model instance
Expand source code
def text_regression_model(name, train_data, preproc=None, metrics=["mae"], verbose=1):
    Build and return a text regression model.
    Args:
        name (string): one of:
                      - 'fasttext' for FastText model
                      - 'nbsvm' for NBSVM model
                      - 'linreg' for linear regression using embedding layers
                      - 'bigru' for Bidirectional GRU with pretrained word vectors
                      - 'bert' for BERT Text Classification
                      - 'distilbert' for Hugging Face DistilBert model
        train_data (tuple): a tuple of numpy.ndarrays: (x_train, y_train)
        preproc: a ktrain.text.TextPreprocessor instance.
                 As of v0.8.0, this is required.
        metrics(list): metrics to use
        verbose (boolean): verbosity of output
    Return:
        model (Model): A Keras Model instance
    if name not in TEXT_REGRESSION_MODELS:
        raise ValueError("invalid name for text classification: %s" % (name))
    if preproc is not None and preproc.get_classes():
        raise ValueError(
            "preproc.get_classes() is supposed to be empty for text regression tasks"
    return _text_model(
        name,
        train_data,
        preproc=preproc,
        multilabel=False,
        classification=False,
        metrics=metrics,
        verbose=verbose,
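
A parallel sketch for text_regression_model, assuming (trn, val, preproc) came from a texts_from_* loader with a numerical target (empty class_names), so preproc.get_classes() is empty.

import ktrain
from ktrain import text

model = text.text_regression_model('fasttext', train_data=trn, preproc=preproc)
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=32)
learner.autofit(1e-3, 2)   # illustrative learning rate and epoch budget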
def texts_from_array(x_train, y_train, x_test=None, y_test=None, class_names=[], max_features=20000, maxlen=400, val_pct=0.1, ngram_range=1, preprocess_mode='standard', lang=None, random_state=None, verbose=1)
Loads and preprocesses text data from arrays.
texts_from_array can handle data for both text classification
and text regression.  If class_names is empty, a regression task is assumed.
Args:
    x_train(list): list of training texts
    y_train(list): labels in one of the following forms:
                   1. list of integers representing classes (class_names is required)
                   2. list of strings representing classes (class_names is not needed and ignored.)
                   3. a one or multi hot encoded array representing classes (class_names is required)
                   4. numerical values for text regression (class_names should be left empty)
    x_test(list): list of validation/test texts
    y_test(list): labels in one of the following forms:
                   1. list of integers representing classes (class_names is required)
                   2. list of strings representing classes (class_names is not needed and ignored.)
                   3. a one or multi hot encoded array representing classes (class_names is required)
                   4. numerical values for text regression (class_names should be left empty)
    class_names (list): list of strings representing class labels
                        shape should be (num_examples,1) or (num_examples,)
    max_features(int): max num of words to consider in vocabulary
                       Note: This is only used for preprocess_mode='standard'.
    maxlen(int): each document can be at most <maxlen> words. 0 is used as padding ID.
    ngram_range(int): size of multi-word phrases to consider
                      e.g., 2 will consider both 1-word phrases and 2-word phrases
                           limited by max_features
    val_pct(float): Proportion of training to use for validation.
                    Has no effect if x_val and  y_val is supplied.
    preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                            tokenization and preprocessing for use with
                            BERT/DistilBert text classification model.
    lang (str):            language.  Auto-detected if None.
    random_state(int):      If integer is supplied, train/test split is reproducible.
                            If None, train/test split will be random.
    verbose (boolean): verbosity
Expand source code
def texts_from_array(
    x_train,
    y_train,
    x_test=None,
    y_test=None,
    class_names=[],
    max_features=MAX_FEATURES,
    maxlen=MAXLEN,
    val_pct=0.1,
    ngram_range=1,
    preprocess_mode="standard",
    lang=None,  # auto-detected
    random_state=None,
    verbose=1,
    Loads and preprocesses text data from arrays.
    texts_from_array can handle data for both text classification
    and text regression.  If class_names is empty, a regression task is assumed.
    Args:
        x_train(list): list of training texts
        y_train(list): labels in one of the following forms:
                       1. list of integers representing classes (class_names is required)
                       2. list of strings representing classes (class_names is not needed and ignored.)
                       3. a one or multi hot encoded array representing classes (class_names is required)
                       4. numerical values for text regression (class_names should be left empty)
        x_test(list): list of validation/test texts
        y_test(list): labels in one of the following forms:
                       1. list of integers representing classes (class_names is required)
                       2. list of strings representing classes (class_names is not needed and ignored.)
                       3. a one or multi hot encoded array representing classes (class_names is required)
                       4. numerical values for text regression (class_names should be left empty)
        class_names (list): list of strings representing class labels
                            shape should be (num_examples,1) or (num_examples,)
        max_features(int): max num of words to consider in vocabulary
                           Note: This is only used for preprocess_mode='standard'.
        maxlen(int): each document can be at most <maxlen> words. 0 is used as padding ID.
        ngram_range(int): size of multi-word phrases to consider
                          e.g., 2 will consider both 1-word phrases and 2-word phrases
                               limited by max_features
        val_pct(float): Proportion of training to use for validation.
                        Has no effect if x_val and  y_val is supplied.
        preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                                tokenization and preprocessing for use with
                                BERT/DistilBert text classification model.
        lang (str):            language.  Auto-detected if None.
        random_state(int):      If integer is supplied, train/test split is reproducible.
                                If None, train/test split will be random.
        verbose (boolean): verbosity
    U.check_array(x_train, y=y_train, X_name="x_train", y_name="y_train")
    if x_test is None or y_test is None:
        x_train, x_test, y_train, y_test = train_test_split(
            x_train, y_train, test_size=val_pct, random_state=random_state
    else:
        U.check_array(x_test, y=y_test, X_name="x_test", y_name="y_test")
    # removed as TextPreprocessor now handles this.
    # if isinstance(y_train[0], str):
    # if not isinstance(y_test[0], str):
    # raise ValueError('y_train contains strings, but y_test does not')
    # encoder = LabelEncoder()
    # encoder.fit(y_train)
    # y_train = encoder.transform(y_train)
    # y_test = encoder.transform(y_test)
    # detect language
    if lang is None:
        lang = TU.detect_lang(x_train)
    check_unsupported_lang(lang, preprocess_mode)
    # return the preprocessed texts
    preproc_type = tpp.TEXT_PREPROCESSORS.get(preprocess_mode, None)
    if preproc_type is None:
        raise ValueError("unsupported preprocess_mode")
    preproc = preproc_type(
        maxlen,
        max_features,
        class_names=class_names,
        lang=lang,
        ngram_range=ngram_range,
    trn = preproc.preprocess_train(x_train, y_train, verbose=verbose)
    val = preproc.preprocess_test(x_test, y_test, verbose=verbose)
    if not preproc.get_classes() and verbose:
        print(
            "task: text regression (supply class_names argument if this is supposed to be classification task)"
    else:
        print("task: text classification")
    return (trn, val, preproc)
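
A toy example of texts_from_array with string labels (class_names is then inferred); the four sentences stand in for a real corpus.

from ktrain import text

x_train = ['I loved this movie.', 'A wonderful film.',
           'Terrible acting.', 'A weak and boring plot.']
y_train = ['pos', 'pos', 'neg', 'neg']

trn, val, preproc = text.texts_from_array(x_train, y_train,
                                          preprocess_mode='standard',
                                          maxlen=50, val_pct=0.25)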
def texts_from_csv(train_filepath, text_column, label_columns=[], val_filepath=None, max_features=20000, maxlen=400, val_pct=0.1, ngram_range=1, preprocess_mode='standard', encoding=None, lang=None, sep=',', is_regression=False, random_state=None, verbose=1)
Loads text data from CSV or TSV file. Class labels are assumed to be
one of the following formats:
    1. one-hot-encoded or multi-hot-encoded arrays representing classes:
          Example with label_columns=['positive', 'negative'] and text_column='text':
            text|positive|negative
            I like this movie.|1|0
            I hated this movie.|0|1
        Classification will have a single one in each row: [[1,0,0], [0,1,0]]
        Multi-label classification will have one or more ones in each row: [[1,1,0], [0,1,1]]
    2. labels are in a single column of string or integer values representing class labels
           Example with label_columns=['label'] and text_column='text':
             text|label
             I like this movie.|positive
             I hated this movie.|negative
   3. labels are a single column of numerical values for text regression
      NOTE: Must supply is_regression=True for labels to be treated as numerical targets
             wine_description|wine_price
             Exquisite wine!|100
             Wine for budget shoppers|8
Args:
    train_filepath(str): file path to training CSV
    text_column(str): name of column containing the text
    label_columns(list): list of columns that are to be treated as labels
    val_filepath(string): file path to test CSV.  If not supplied,
                           10% of documents in training CSV will be
                           used for testing/validation.
    max_features(int): max num of words to consider in vocabulary
                       Note: This is only used for preprocess_mode='standard'.
    maxlen(int): each document can be at most <maxlen> words. 0 is used as padding ID.
    ngram_range(int): size of multi-word phrases to consider
                      e.g., 2 will consider both 1-word phrases and 2-word phrases
                           limited by max_features
    val_pct(float): Proportion of training to use for validation.
                    Has no effect if val_filepath is supplied.
    preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                            tokenization and preprocessing for use with
                            BERT/DistilBert text classification model.
    encoding (str):        character encoding to use. Auto-detected if None
    lang (str):            language.  Auto-detected if None.
    sep(str):              delimiter for CSV (comma is default)
    is_regression(bool):  If True, integer targets will be treated as numerical targets instead of class IDs
    random_state(int):      If integer is supplied, train/test split is reproducible.
                            If None, train/test split will be random
    verbose (boolean): verbosity
Expand source code
def texts_from_csv(
    train_filepath,
    text_column,
    label_columns=[],
    val_filepath=None,
    max_features=MAX_FEATURES,
    maxlen=MAXLEN,
    val_pct=0.1,
    ngram_range=1,
    preprocess_mode="standard",
    encoding=None,  # auto-detected
    lang=None,  # auto-detected
    sep=",",
    is_regression=False,
    random_state=None,
    verbose=1,
    Loads text data from CSV or TSV file. Class labels are assumed to be
    one of the following formats:
        1. one-hot-encoded or multi-hot-encoded arrays representing classes:
              Example with label_columns=['positive', 'negative'] and text_column='text':
                text|positive|negative
                I like this movie.|1|0
                I hated this movie.|0|1
            Classification will have a single one in each row: [[1,0,0], [0,1,0]]
            Multi-label classification will have one or more ones in each row: [[1,1,0], [0,1,1]]
        2. labels are in a single column of string or integer values representing class labels
               Example with label_columns=['label'] and text_column='text':
                 text|label
                 I like this movie.|positive
                 I hated this movie.|negative
       3. labels are a single column of numerical values for text regression
          NOTE: Must supply is_regression=True for labels to be treated as numerical targets
                 wine_description|wine_price
                 Exquisite wine!|100
                 Wine for budget shoppers|8
    Args:
        train_filepath(str): file path to training CSV
        text_column(str): name of column containing the text
        label_columns(list): list of columns that are to be treated as labels
        val_filepath(string): file path to test CSV.  If not supplied,
                               10% of documents in training CSV will be
                               used for testing/validation.
        max_features(int): max num of words to consider in vocabulary
                           Note: This is only used for preprocess_mode='standard'.
        maxlen(int): each document can be at most <maxlen> words. 0 is used as padding ID.
        ngram_range(int): size of multi-word phrases to consider
                          e.g., 2 will consider both 1-word phrases and 2-word phrases
                               limited by max_features
        val_pct(float): Proportion of training to use for validation.
                        Has no effect if val_filepath is supplied.
        preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                                tokenization and preprocessing for use with
                                BERT/DistilBert text classification model.
        encoding (str):        character encoding to use. Auto-detected if None
        lang (str):            language.  Auto-detected if None.
        sep(str):              delimiter for CSV (comma is default)
        is_regression(bool):  If True, integer targets will be treated as numerical targets instead of class IDs
        random_state(int):      If integer is supplied, train/test split is reproducible.
                                If None, train/test split will be random
        verbose (boolean): verbosity
    if encoding is None:
        with open(train_filepath, "rb") as f:
            encoding = TU.detect_encoding(f.read())
            U.vprint(
                "detected encoding: %s (if wrong, set manually)" % (encoding),
                verbose=verbose,
    train_df = pd.read_csv(train_filepath, encoding=encoding, sep=sep)
    val_df = (
        pd.read_csv(val_filepath, encoding=encoding, sep=sep)
        if val_filepath is not None
        else None
    return texts_from_df(
        train_df,
        text_column,
        label_columns=label_columns,
        val_df=val_df,
        max_features=max_features,
        maxlen=maxlen,
        val_pct=val_pct,
        ngram_range=ngram_range,
        preprocess_mode=preprocess_mode,
        lang=lang,
        is_regression=is_regression,
        random_state=random_state,
        verbose=verbose,
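
An illustrative texts_from_csv call for format 1 above (multi-hot label columns); 'data/train.csv' is a placeholder file with text, positive, and negative columns.

from ktrain import text

trn, val, preproc = text.texts_from_csv('data/train.csv',
                                        text_column='text',
                                        label_columns=['positive', 'negative'],
                                        preprocess_mode='bert',
                                        maxlen=128, val_pct=0.1)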
def texts_from_df(train_df, text_column, label_columns=[], val_df=None, max_features=20000, maxlen=400, val_pct=0.1, ngram_range=1, preprocess_mode='standard', lang=None, is_regression=False, random_state=None, verbose=1)
Loads text data from Pandas dataframe file. Class labels are assumed to be
one of the following formats:
    1. one-hot-encoded or multi-hot-encoded arrays representing classes:
          Example with label_columns=['positive', 'negative'] and text_column='text':
            text|positive|negative
            I like this movie.|1|0
            I hated this movie.|0|1
        Classification will have a single one in each row: [[1,0,0], [0,1,0]]
        Multi-label classification will have one or more ones in each row: [[1,1,0], [0,1,1]]
    2. labels are in a single column of string or integer values representing class labels
           Example with label_columns=['label'] and text_column='text':
             text|label
             I like this movie.|positive
             I hated this movie.|negative
   3. labels are a single column of numerical values for text regression
      NOTE: Must supply is_regression=True for integer labels to be treated as numerical targets
             wine_description|wine_price
             Exquisite wine!|100
             Wine for budget shoppers|8
Args:
    train_df(dataframe): Pandas dataframe
    text_column(str): name of column containing the text
    label_columns(list): list of columns that are to be treated as labels
    val_df(dataframe): validation/test dataframe.  If not supplied,
                           10% of documents in training df will be
                           used for testing/validation.
    max_features(int): max num of words to consider in vocabulary.
                       Note: This is only used for preprocess_mode='standard'.
    maxlen(int): each document can be at most <maxlen> words. 0 is used as padding ID.
    ngram_range(int): size of multi-word phrases to consider
                      e.g., 2 will consider both 1-word phrases and 2-word phrases
                           limited by max_features
    val_pct(float): Proportion of training to use for validation.
                    Has no effect if val_filepath is supplied.
    preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                            tokenization and preprocessing for use with
                            BERT/DistilBert text classification model.
    lang (str):            language.  Auto-detected if None.
    is_regression(bool):  If True, integer targets will be treated as numerical targets instead of class IDs
    random_state(int):      If integer is supplied, train/test split is reproducible.
                            If None, train/test split will be random
    verbose (boolean): verbosity
Expand source code
def texts_from_df(
    train_df,
    text_column,
    label_columns=[],
    val_df=None,
    max_features=MAX_FEATURES,
    maxlen=MAXLEN,
    val_pct=0.1,
    ngram_range=1,
    preprocess_mode="standard",
    lang=None,  # auto-detected
    is_regression=False,
    random_state=None,
    verbose=1,
    Loads text data from Pandas dataframe file. Class labels are assumed to be
    one of the following formats:
        1. one-hot-encoded or multi-hot-encoded arrays representing classes:
              Example with label_columns=['positive', 'negative'] and text_column='text':
                text|positive|negative
                I like this movie.|1|0
                I hated this movie.|0|1
            Classification will have a single one in each row: [[1,0,0], [0,1,0]]
            Multi-label classification will have one or more ones in each row: [[1,1,0], [0,1,1]]
        2. labels are in a single column of string or integer values representing class labels
               Example with label_columns=['label'] and text_column='text':
                 text|label
                 I like this movie.|positive
                 I hated this movie.|negative
       3. labels are a single column of numerical values for text regression
          NOTE: Must supply is_regression=True for integer labels to be treated as numerical targets
                 wine_description|wine_price
                 Exquisite wine!|100
                 Wine for budget shoppers|8
    Args:
        train_df(dataframe): Pandas dataframe
        text_column(str): name of column containing the text
        label_columns(list): list of columns that are to be treated as labels
        val_df(dataframe): validation/test dataframe.  If not supplied,
                               10% of documents in training df will be
                               used for testing/validation.
        max_features(int): max num of words to consider in vocabulary.
                           Note: This is only used for preprocess_mode='standard'.
        maxlen(int): each document can be at most <maxlen> words. 0 is used as padding ID.
        ngram_range(int): size of multi-word phrases to consider
                          e.g., 2 will consider both 1-word phrases and 2-word phrases
                               limited by max_features
        val_pct(float): Proportion of training to use for validation.
                        Has no effect if val_filepath is supplied.
        preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                                tokenization and preprocessing for use with
                                BERT/DistilBert text classification model.
        lang (str):            language.  Auto-detected if None.
        is_regression(bool):  If True, integer targets will be treated as numerical targets instead of class IDs
        random_state(int):      If integer is supplied, train/test split is reproducible.
                                If None, train/test split will be random
        verbose (boolean): verbosity
    # read in train and test data
    train_df = train_df.copy()
    train_df[text_column].fillna("fillna", inplace=True)
    if val_df is not None:
        val_df = val_df.copy()
        val_df[text_column].fillna("fillna", inplace=True)
    else:
        train_df, val_df = train_test_split(
            train_df, test_size=val_pct, random_state=random_state
    # transform labels
    ytransdf = U.YTransformDataFrame(label_columns, is_regression=is_regression)
    t_df = ytransdf.apply_train(train_df)
    v_df = ytransdf.apply_test(val_df)
    class_names = ytransdf.get_classes()
    new_lab_cols = ytransdf.get_label_columns(squeeze=True)
    x_train = t_df[text_column].values
    y_train = t_df[new_lab_cols].values
    x_test = v_df[text_column].values
    y_test = v_df[new_lab_cols].values
    # detect language
    if lang is None:
        lang = TU.detect_lang(x_train)
    check_unsupported_lang(lang, preprocess_mode)
    # return the preprocessed texts
    preproc_type = tpp.TEXT_PREPROCESSORS.get(preprocess_mode, None)
    if preproc_type is None:
        raise ValueError("unsupported preprocess_mode")
    preproc = preproc_type(
        maxlen,
        max_features,
        class_names=class_names,
        lang=lang,
        ngram_range=ngram_range,
    trn = preproc.preprocess_train(x_train, y_train, verbose=verbose)
    val = preproc.preprocess_test(x_test, y_test, verbose=verbose)
    # QUICKFIX for #314
    preproc.ytransform.le = ytransdf.le
    return (trn, val, preproc)
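
A small texts_from_df sketch for format 2 above (a single string-label column); the DataFrame is toy data standing in for a real dataset.

import pandas as pd
from ktrain import text

df = pd.DataFrame({'text': ['I like this movie.', 'A great film.',
                            'I hated this movie.', 'A dreadful film.'],
                   'label': ['positive', 'positive', 'negative', 'negative']})
trn, val, preproc = text.texts_from_df(df, text_column='text',
                                       label_columns=['label'],
                                       maxlen=100, val_pct=0.25)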
def texts_from_folder(datadir, classes=None, max_features=20000, maxlen=400, ngram_range=1, train_test_names=['train', 'test'], preprocess_mode='standard', encoding=None, lang=None, val_pct=0.1, random_state=None, verbose=1)
Returns corpus as sequence of word IDs.
Assumes corpus is in the following folder structure:
├── datadir
│   ├── train
│   │   ├── class0       # folder containing documents of class 0
│   │   ├── class1       # folder containing documents of class 1
│   │   ├── class2       # folder containing documents of class 2
│   │   └── classN       # folder containing documents of class N
│   └── test
│       ├── class0       # folder containing documents of class 0
│       ├── class1       # folder containing documents of class 1
│       ├── class2       # folder containing documents of class 2
│       └── classN       # folder containing documents of class N
Each subfolder should contain documents in plain text format.
If train and test contain additional subfolders that do not represent
classes, they can be ignored by explicitly listing the subfolders of
interest using the classes argument.
Args:
    datadir (str): path to folder
    classes (list): list of classes (subfolders to consider).
                    This is simply supplied as the categories argument
                    to sklearn's load_files function.
    max_features (int):  maximum number of unigrams to consider
                         Note: This is only used for preprocess_mode='standard'.
    maxlen (int):  maximum length of tokens in document
    ngram_range (int):  If > 1, n-grams up to this size are included (e.g., 2 adds bigrams; 3 adds bigrams and trigrams)
    train_test_names (list):  list of strings representing the subfolder
                             name for train and validation sets
                             if test name is missing, <val_pct> of training
                             will be used for validation
    preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                            tokenization and preprocessing for use with
                            BERT/DistilBert text classification model.
    encoding (str):        character encoding to use. Auto-detected if None
    lang (str):            language.  Auto-detected if None.
    val_pct(float):        Only used if train_test_names has 1 name instead of 2
    random_state(int):      If integer is supplied, train/test split is reproducible.
                            If None, train/test split will be random
    verbose (bool):         verbosity
Expand source code
def texts_from_folder(
    datadir,
    classes=None,
    max_features=20000,
    maxlen=400,
    ngram_range=1,
    train_test_names=["train", "test"],
    preprocess_mode="standard",
    encoding=None,  # detected automatically
    lang=None,  # detected automatically
    val_pct=0.1,
    random_state=None,
    verbose=1,
):
    Returns corpus as sequence of word IDs.
    Assumes corpus is in the following folder structure:
    ├── datadir
    │   ├── train
    │   │   ├── class0       # folder containing documents of class 0
    │   │   ├── class1       # folder containing documents of class 1
    │   │   ├── class2       # folder containing documents of class 2
    │   │   └── classN       # folder containing documents of class N
    │   └── test
    │       ├── class0       # folder containing documents of class 0
    │       ├── class1       # folder containing documents of class 1
    │       ├── class2       # folder containing documents of class 2
    │       └── classN       # folder containing documents of class N
    Each subfolder should contain documents in plain text format.
    If train and test contain additional subfolders that do not represent
    classes, they can be ignored by explicitly listing the subfolders of
    interest using the classes argument.
    Args:
        datadir (str): path to folder
        classes (list): list of classes (subfolders to consider).
                        This is simply supplied as the categories argument
                        to sklearn's load_files function.
        max_features (int):  maximum number of unigrams to consider
                             Note: This is only used for preprocess_mode='standard'.
        maxlen (int):  maximum length of tokens in document
        ngram_range (int):  If > 1, n-grams up to this size are included
                            (e.g., 2 adds bigrams; 3 adds bigrams and trigrams)
        train_test_names (list):  list of strings representing the subfolder
                                 names for the train and validation sets.
                                 If the test name is missing, <val_pct> of the training
                                 data will be used for validation.
        preprocess_mode (str):  Either 'standard' (normal tokenization) or one of {'bert', 'distilbert'}
                                tokenization and preprocessing for use with
                                BERT/DistilBert text classification model.
        encoding (str):        character encoding to use. Auto-detected if None
        lang (str):            language.  Auto-detected if None.
        val_pct(float):        Only used if train_test_names has 1 name instead of 2
        random_state(int):      If an integer is supplied, the train/test split is reproducible.
                                If None, the train/test split will be random.
        verbose (bool):         verbosity
    # check train_test_names
    if len(train_test_names) < 1 or len(train_test_names) > 2:
        raise ValueError(
            "train_test_names must have 1 or 2 elements for train and optionally validation"
        )
    # read in training and test corpora
    train_str = train_test_names[0]
    train_b = load_files(
        os.path.join(datadir, train_str), shuffle=True, categories=classes
    )
    if len(train_test_names) > 1:
        test_str = train_test_names[1]
        test_b = load_files(
            os.path.join(datadir, test_str), shuffle=False, categories=classes
        )
        x_train = train_b.data
        y_train = train_b.target
        x_test = test_b.data
        y_test = test_b.target
    else:
        x_train, x_test, y_train, y_test = train_test_split(
            train_b.data, train_b.target, test_size=val_pct, random_state=random_state
        )
    # decode based on supplied encoding
    if encoding is None:
        encoding = TU.detect_encoding(x_train)
        U.vprint("detected encoding: %s" % (encoding), verbose=verbose)
    try:
        x_train = [x.decode(encoding) for x in x_train]
        x_test = [x.decode(encoding) for x in x_test]
    except:
        U.vprint(
            "Decoding with %s failed 1st attempt - using %s with skips"
            % (encoding, encoding),
            verbose=verbose,
        )
        x_train = TU.decode_by_line(x_train, encoding=encoding, verbose=verbose)
        x_test = TU.decode_by_line(x_test, encoding=encoding, verbose=verbose)
    # detect language
    if lang is None:
        lang = TU.detect_lang(x_train)
    check_unsupported_lang(lang, preprocess_mode)
    # return preprocessed texts
    preproc_type = tpp.TEXT_PREPROCESSORS.get(preprocess_mode, None)
    if preproc_type is None:
        raise ValueError("unsupported preprocess_mode")
    preproc = preproc_type(
        maxlen,
        max_features,
        class_names=train_b.target_names,
        lang=lang,
        ngram_range=ngram_range,
    )
    trn = preproc.preprocess_train(x_train, y_train, verbose=verbose)
    val = preproc.preprocess_test(x_test, y_test, verbose=verbose)
    return (trn, val, preproc)
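Example (a minimal sketch; the folder path and class names are illustrative and assume the directory layout shown above):

from ktrain import text

(trn, val, preproc) = text.texts_from_folder(
    "data/aclImdb",
    classes=["pos", "neg"],        # only consider these subfolders
    maxlen=400,
    preprocess_mode="standard",    # or 'bert'/'distilbert' for transformer models
)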
class AnswerExtractor (model_name='bert-large-uncased-whole-word-masking-finetuned-squad', bert_squad_model=None, framework='tf', device=None, quantize=False)

Question-Answering-based Information Extraction

Extracts information from documents using Question-Answering.
  model_name(str): name of Question-Answering model (e.g., BERT SQUAD) to use
  bert_squad_model(str): alias for model_name (deprecated)
  framework(str): 'tf' for TensorFlow or 'pt' for PyTorch
  device(str): Torch device to use (e.g., 'cuda', 'cpu'). Ignored if framework=='tf'.
               If framework=='tf', use CUDA_VISIBLE_DEVICES environment variable
               to select device.
  quantize(bool): If True and framework=='pt' and device != 'cpu', then faster quantized inference is used.
              Ignored if framework=="tf".
Expand source code
class AnswerExtractor:
    Question-Answering-based Information Extraction
    def __init__(
        self,
        model_name=DEFAULT_MODEL,
        bert_squad_model=None,
        framework="tf",
        device=None,
        quantize=False,
        Extracts information from documents using Question-Answering.
          model_name(str): name of Question-Answering model (e.g., BERT SQUAD) to use
          bert_squad_model(str): alias for model_name (deprecated)
          framework(str): 'tf' for TensorFlow or 'pt' for PyTorch
          device(str): Torch device to use (e.g., 'cuda', 'cpu'). Ignored if framework=='tf'.
                       If framework=='tf', use CUDA_VISIBLE_DEVICES environment variable
                       to select device.
          quantize(bool): If True and framework=='pt' and device != 'cpu', then faster quantized inference is used.
                      Ignored if framework=="tf".
        self.qa = _QAExtractor(
            model_name=model_name,
            bert_squad_model=bert_squad_model,
            framework=framework,
            device=device,
            quantize=quantize,
        )
        return
    def _check_columns(self, labels, df):
        """check columns"""
        cols = df.columns.values
        for l in labels:
            if l in cols:
                raise ValueError(
                    "There is already a column named %s in your DataFrame." % (l)
                )
    def _extract(
        self,
        questions,
        contexts,
        min_conf=DEFAULT_MIN_CONF,
        return_conf=False,
        batch_size=8,
        Extracts answers
        num_rows = len(contexts)
        doc_results = [
            {"rawtext": rawtext, "reference": row}
            for row, rawtext in enumerate(contexts)
        ]
        cols = []
        for q in questions:
            result_dict = {}
            conf_dict = {}
            answers = self.qa.ask(q, doc_results=doc_results, batch_size=batch_size)
            for a in answers:
                answer = a["answer"] if a["confidence"] > min_conf else None
                lst = result_dict.get(a["reference"], [])
                lst.append(answer)
                result_dict[a["reference"]] = lst
                lst = conf_dict.get(a["reference"], [])
                lst.append(a["confidence"])
                conf_dict[a["reference"]] = lst
            results = []
            for i in range(num_rows):
                ans = [a for a in result_dict[i] if a is not None]
                results.append(None if not ans else " | ".join(ans))
            cols.append(results)
            if return_conf:
                confs = []
                for i in range(num_rows):
                    conf = [str(round(c, 2)) for c in conf_dict[i] if c is not None]
                    confs.append(None if not conf else " | ".join(conf))
                cols.append(confs)
        return cols
    def extract(
        self,
        texts,
        df,
        question_label_pairs,
        min_conf=DEFAULT_MIN_CONF,
        return_conf=False,
        batch_size=8,
    ):
        Extracts answers from texts
        Args:
          texts(list): list of strings
          df(pd.DataFrame): original DataFrame to which columns need to be added
          question_label_pairs(list):  A list of tuples of the form (question, label).
                                     Extracted answers to the question will be added as new columns with the
                                     specified labels.
                                     Example: ('What are the risk factors?', 'Risk Factors')
          min_conf(float):  Answers at or below this confidence value will be set to None in the results
                            Default: 5.0
                            Lower this value to reduce false negatives.
                            Raise this value to reduce false positives.
          return_conf(bool): If True, confidence score of each extraction is included in results
          batch_size(int): batch size. Default: 8
        if not isinstance(df, pd.DataFrame):
            raise ValueError("df must be a pandas DataFrame.")
        if len(texts) != df.shape[0]:
            raise ValueError(
                "Number of texts is not equal to the number of rows in the DataFrame."
        # texts = [t.replace("\n", " ").replace("\t", " ") for t in texts]
        texts = [t.replace("\t", " ") for t in texts]
        questions = [q for q, l in question_label_pairs]
        labels = [l for q, l in question_label_pairs]
        self._check_columns(labels, df)
        cols = self._extract(
            questions,
            texts,
            min_conf=min_conf,
            return_conf=return_conf,
            batch_size=batch_size,
        data = list(zip(*cols)) if len(cols) > 1 else cols[0]
        if return_conf:
            labels = twolists(labels, [l + " CONF" for l in labels])
        return df.join(pd.DataFrame(data, columns=labels, index=df.index))
    def finetune(
        self, data, epochs=3, learning_rate=2e-5, batch_size=8, max_seq_length=512
        Finetune a QA model.
        Args:
          data(list): list of dictionaries of the form:
                      [{'question': 'What is ktrain?'
                       'context': 'ktrain is a low-code library for augmented machine learning.'
                       'answer': 'ktrain'}]
          epochs(int): number of epochs.  Default:3
          learning_rate(float): learning rate.  Default: 2e-5
          batch_size(int): batch size. Default:8
          max_seq_length(int): maximum sequence length.  Default:512
        Returns:
        if self.qa.framework != "tf":
            raise ValueError(
                'The finetune method does not currently support the framework="pt" option. Please use framework="tf" to finetune.'
        from .qa_finetuner import QAFineTuner
        ft = QAFineTuner(self.qa.model, self.qa.tokenizer)
        model = ft.finetune(
            data, epochs=epochs, learning_rate=learning_rate, batch_size=batch_size
        return

Methods

def extract(self, texts, df, question_label_pairs, min_conf=6, return_conf=False, batch_size=8)
Extracts answers from texts
Args:
  texts(list): list of strings
  df(pd.DataFrame): original DataFrame to which columns need to be added
  question_label_pairs(list):  A list of tuples of the form (question, label).
                             Extracted answers to the question will be added as new columns with the
                             specified labels.
                             Example: ('What are the risk factors?', 'Risk Factors')
  min_conf(float):  Answers at or below this confidence value will be set to None in the results
                    Default: 5.0
                    Lower this value to reduce false negatives.
                    Raise this value to reduce false positives.
  return_conf(bool): If True, confidence score of each extraction is included in results
  batch_size(int): batch size. Default: 8
Expand source code
def extract(
    self,
    texts,
    df,
    question_label_pairs,
    min_conf=DEFAULT_MIN_CONF,
    return_conf=False,
    batch_size=8,
):
    Extracts answers from texts
    Args:
      texts(list): list of strings
      df(pd.DataFrame): original DataFrame to which columns need to be added
      question_label_pairs(list):  A list of tuples of the form (question, label).
                                 Extracted answers to the question will be added as new columns with the
                                 specified labels.
                                 Example: ('What are the risk factors?', 'Risk Factors')
      min_conf(float):  Answers at or below this confidence value will be set to None in the results
                        Default: 5.0
                        Lower this value to reduce false negatives.
                        Raise this value to reduce false positives.
      return_conf(bool): If True, confidence score of each extraction is included in results
      batch_size(int): batch size. Default: 8
    if not isinstance(df, pd.DataFrame):
        raise ValueError("df must be a pandas DataFrame.")
    if len(texts) != df.shape[0]:
        raise ValueError(
            "Number of texts is not equal to the number of rows in the DataFrame."
    # texts = [t.replace("\n", " ").replace("\t", " ") for t in texts]
    texts = [t.replace("\t", " ") for t in texts]
    questions = [q for q, l in question_label_pairs]
    labels = [l for q, l in question_label_pairs]
    self._check_columns(labels, df)
    cols = self._extract(
        questions,
        texts,
        min_conf=min_conf,
        return_conf=return_conf,
        batch_size=batch_size,
    data = list(zip(*cols)) if len(cols) > 1 else cols[0]
    if return_conf:
        labels = twolists(labels, [l + " CONF" for l in labels])
    return df.join(pd.DataFrame(data, columns=labels, index=df.index))
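Example (a minimal sketch; the DataFrame contents and question are illustrative, and the default SQuAD model is downloaded on first use):

import pandas as pd
from ktrain.text import AnswerExtractor

df = pd.DataFrame({"Text": ["The patient presented with fever, cough, and fatigue."]})
ae = AnswerExtractor()
df = ae.extract(df.Text.values, df,
                [("What are the symptoms?", "Symptoms")],
                return_conf=True)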
def finetune(self, data, epochs=3, learning_rate=2e-05, batch_size=8, max_seq_length=512)
Finetune a QA model.
Args:
  data(list): list of dictionaries of the form:
              [{'question': 'What is ktrain?'
               'context': 'ktrain is a low-code library for augmented machine learning.'
               'answer': 'ktrain'}]
  epochs(int): number of epochs.  Default:3
  learning_rate(float): learning rate.  Default: 2e-5
  batch_size(int): batch size. Default:8
  max_seq_length(int): maximum sequence length.  Default:512
Returns:
Expand source code
def finetune(
    self, data, epochs=3, learning_rate=2e-5, batch_size=8, max_seq_length=512
    Finetune a QA model.
    Args:
      data(list): list of dictionaries of the form:
                  [{'question': 'What is ktrain?'
                   'context': 'ktrain is a low-code library for augmented machine learning.'
                   'answer': 'ktrain'}]
      epochs(int): number of epochs.  Default:3
      learning_rate(float): learning rate.  Default: 2e-5
      batch_size(int): batch size. Default:8
      max_seq_length(int): maximum sequence length.  Default:512
    Returns:
    if self.qa.framework != "tf":
        raise ValueError(
            'The finetune method does not currently support the framework="pt" option. Please use framework="tf" to finetune.'
    from .qa_finetuner import QAFineTuner
    ft = QAFineTuner(self.qa.model, self.qa.tokenizer)
    model = ft.finetune(
        data, epochs=epochs, learning_rate=learning_rate, batch_size=batch_size
    return
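Example (a minimal sketch of the expected data format; a single example is shown only for illustration, and a real fine-tuning set would contain many more):

from ktrain.text import AnswerExtractor

data = [
    {"question": "What is ktrain?",
     "context": "ktrain is a low-code library for augmented machine learning.",
     "answer": "ktrain"},
]
ae = AnswerExtractor(framework="tf")   # finetune currently requires the TensorFlow backend
ae.finetune(data, epochs=3, learning_rate=2e-5, batch_size=8)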

class EnglishTranslator (src_lang=None, device=None, quantize=False)

Class to translate text in various languages to English.

Constructor for English translator
Args:
  src_lang(str): language code of source language.
                 Must be one of SUPPORTED_SRC_LANGS:
                   'zh': Chinese (either traditional or simplified)
                   'ar': Arabic
                   'ru' : Russian
                   'de': German
                   'af': Afrikaans
                   'es': Spanish
                   'fr': French
                   'it': Italian
                   'pt': Portuguese
  device(str): device to use (e.g., 'cuda', 'cpu')
  quantize(bool): If True, use quantization.
Expand source code
class EnglishTranslator:
    Class to translate text in various languages to English.
    def __init__(self, src_lang=None, device=None, quantize=False):
        Constructor for English translator
        Args:
          src_lang(str): language code of source language.
                         Must be one of SUPPORTED_SRC_LANGS:
                           'zh': Chinese (either traditional or simplified)
                           'ar': Arabic
                           'ru' : Russian
                           'de': German
                           'af': Afrikaans
                           'es': Spanish
                           'fr': French
                           'it': Italian
                           'pt': Portuguese
          device(str): device to use (e.g., 'cuda', 'cpu')
          quantize(bool): If True, use quantization.
        if src_lang is None or src_lang not in SUPPORTED_SRC_LANGS:
            raise ValueError(
                "A src_lang must be supplied and be one of: %s" % (SUPPORTED_SRC_LANGS)
        self.src_lang = src_lang
        self.translators = []
        if src_lang == "ar":
            self.translators.append(
                Translator(
                    model_name="Helsinki-NLP/opus-mt-ar-en",
                    device=device,
                    quantize=quantize,
        elif src_lang == "ru":
            self.translators.append(
                Translator(
                    model_name="Helsinki-NLP/opus-mt-ru-en",
                    device=device,
                    quantize=quantize,
        elif src_lang == "de":
            self.translators.append(
                Translator(
                    model_name="Helsinki-NLP/opus-mt-de-en",
                    device=device,
                    quantize=quantize,
        elif src_lang == "af":
            self.translators.append(
                Translator(
                    model_name="Helsinki-NLP/opus-mt-af-en",
                    device=device,
                    quantize=quantize,
        elif src_lang in ["es", "fr", "it", "pt"]:
            self.translators.append(
                Translator(
                    model_name="Helsinki-NLP/opus-mt-ROMANCE-en",
                    device=device,
                    quantize=quantize,
        # elif src_lang == 'zh': # could not find zh->en model, so currently doing two-step translation to English via German
        # self.translators.append(Translator(model_name='Helsinki-NLP/opus-mt-ZH-de', device=device))
        # self.translators.append(Translator(model_name='Helsinki-NLP/opus-mt-de-en', device=device))
        elif src_lang == "zh":
            self.translators.append(
                Translator(
                    model_name="Helsinki-NLP/opus-mt-zh-en",
                    device=device,
                    quantize=quantize,
        else:
            raise ValueError("lang:%s is currently not supported." % (src_lang))
    def translate(self, src_text, join_with="\n", num_beams=1, early_stopping=False):
        Translate source document to English.
        To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
        Args:
          src_text(str): source text. Must be in language specified by src_lang (language code) supplied to constructor
                         The source text can either be a single sentence or an entire document with multiple sentences
                         and paragraphs.
                         IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                         If the input text is very large (e.g., an entire book), you should
                                         break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                         feed each chunk separately into translate to avoid out-of-memory issues.
          join_with(str):  list of translated sentences will be delimited with this character.
                           default: each sentence on separate line
          num_beams(int): Number of beams for beam search. Defaults to None.  If None, the transformers library defaults this to 1,
                          which means no beam search.
          early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                                 are finished per batch or not. Defaults to None.  If None, the transformers library
                                 sets this to False.
        Returns:
          str: translated text
        text = src_text
        for t in self.translators:
            text = t.translate(
                text,
                join_with=join_with,
                num_beams=num_beams,
                early_stopping=early_stopping,
        return text

Methods

def translate(self, src_text, join_with='\n', num_beams=1, early_stopping=False)
Translate source document to English.
To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
Args:
  src_text(str): source text. Must be in language specified by src_lang (language code) supplied to constructor
                 The source text can either be a single sentence or an entire document with multiple sentences
                 and paragraphs.
                 IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                 If the input text is very large (e.g., an entire book), you should
                                  break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                 feed each chunk separately into translate to avoid out-of-memory issues.
  join_with(str):  list of translated sentences will be delimited with this character.
                   default: each sentence on separate line
  num_beams(int): Number of beams for beam search. Defaults to None.  If None, the transformers library defaults this to 1,
                  which means no beam search.
  early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                         are finished per batch or not. Defaults to None.  If None, the transformers library
                         sets this to False.
Returns:
  str: translated text
Expand source code
def translate(self, src_text, join_with="\n", num_beams=1, early_stopping=False):
    Translate source document to English.
    To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
    Args:
      src_text(str): source text. Must be in language specified by src_lang (language code) supplied to constructor
                     The source text can either be a single sentence or an entire document with multiple sentences
                     and paragraphs.
                     IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                     If the input text is very large (e.g., an entire book), you should
                                      break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                     feed each chunk separately into translate to avoid out-of-memory issues.
      join_with(str):  list of translated sentences will be delimited with this character.
                       default: each sentence on separate line
      num_beams(int): Number of beams for beam search. Defaults to None.  If None, the transformers library defaults this to 1,
                      which means no beam search.
      early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                             are finished per batch or not. Defaults to None.  If None, the transformers library
                             sets this to False.
    Returns:
      str: translated text
    text = src_text
    for t in self.translators:
        text = t.translate(
            text,
            join_with=join_with,
            num_beams=num_beams,
            early_stopping=early_stopping,
    return text
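Example (a minimal sketch; model weights are downloaded from the Hugging Face hub on first use):

from ktrain.text import EnglishTranslator

translator = EnglishTranslator(src_lang="de")   # German -> English
print(translator.translate("Ich liebe maschinelles Lernen.",
                           num_beams=4, early_stopping=True))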
class SimpleQA (index_dir, model_name='bert-large-uncased-whole-word-masking-finetuned-squad', bert_squad_model=None, bert_emb_model='bert-base-uncased', framework='tf', device=None, quantize=False)

SimpleQA: Question-Answering on a list of texts

SimpleQA constructor
Args:
  index_dir(str):  path to index directory created by SimpleQA.initialize_index
  model_name(str): name of Question-Answering model (e.g., BERT SQUAD) to use
  bert_squad_model(str): alias for model_name (deprecated)
  bert_emb_model(str): BERT model to use to generate embeddings for semantic similarity
  framework(str): 'tf' for TensorFlow or 'pt' for PyTorch
  device(str): Torch device to use (e.g., 'cuda', 'cpu'). Ignored if framework=='tf'.
               If framework=='tf', use CUDA_VISIBLE_DEVICES environment variable
               to select device.
  quantize(bool): If True and framework=='pt' and device != 'cpu', then faster quantized inference is used.
              Ignored if framework=="tf".
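Example (a minimal end-to-end sketch; the index path is illustrative, docs is assumed to be a list of strings, and ask/display_answers are assumed to be inherited from ExtractiveQABase rather than shown in this excerpt):

from ktrain.text import SimpleQA

INDEXDIR = "/tmp/myindex"                      # hypothetical path
SimpleQA.initialize_index(INDEXDIR)            # fails if the directory already exists
SimpleQA.index_from_list(docs, INDEXDIR)       # docs: list of (reasonably long) strings
qa = SimpleQA(INDEXDIR)
answers = qa.ask("What causes computer vision syndrome?")
qa.display_answers(answers[:5])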
Expand source code
class SimpleQA(ExtractiveQABase):
    SimpleQA: Question-Answering on a list of texts
    def __init__(
        self,
        index_dir,
        model_name=DEFAULT_MODEL,
        bert_squad_model=None,  # deprecated
        bert_emb_model="bert-base-uncased",
        framework="tf",
        device=None,
        quantize=False,
        SimpleQA constructor
        Args:
          index_dir(str):  path to index directory created by SimpleQA.initialize_index
          model_name(str): name of Question-Answering model (e.g., BERT SQUAD) to use
          bert_squad_model(str): alias for model_name (deprecated)
          bert_emb_model(str): BERT model to use to generate embeddings for semantic similarity
          framework(str): 'tf' for TensorFlow or 'pt' for PyTorch
          device(str): Torch device to use (e.g., 'cuda', 'cpu'). Ignored if framework=='tf'.
                       If framework=='tf', use CUDA_VISIBLE_DEVICES environment variable
                       to select device.
          quantize(bool): If True and framework=='pt' and device != 'cpu', then faster quantized inference is used.
                      Ignored if framework=="tf".
        self.index_dir = index_dir
        try:
            ix = index.open_dir(self.index_dir)
        except:
            raise ValueError(
                'index_dir has not yet been created - please call SimpleQA.initialize_index("%s")'
                % (self.index_dir)
            )
        super().__init__(
            model_name=model_name,
            bert_squad_model=bert_squad_model,
            bert_emb_model=bert_emb_model,
            framework=framework,
            device=device,
            quantize=quantize,
    def _open_ix(self):
        return index.open_dir(self.index_dir)
    @classmethod
    def initialize_index(cls, index_dir):
        schema = Schema(
            reference=ID(stored=True), content=TEXT, rawtext=TEXT(stored=True)
        if not os.path.exists(index_dir):
            os.makedirs(index_dir)
        else:
            raise ValueError(
                "There is already an existing directory or file with path %s"
                % (index_dir)
        ix = index.create_in(index_dir, schema)
        return ix
    @classmethod
    def index_from_list(
        docs,
        index_dir,
        commit_every=1024,
        breakup_docs=True,
        procs=1,
        limitmb=256,
        multisegment=False,
        min_words=20,
        references=None,
        index documents from list.
        The procs, limitmb, and especially multisegment arguments can be used to
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
        Args:
          docs(list): list of strings representing documents
          index_dir(str): path to index directory (see initialize_index)
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                            Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                               Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                            Strings can also be hyperlinks in which case the label and URL should be separated by a single tab character:
                               Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]
                            These references will be returned in the output of the ask method.
                            If strings are  hyperlinks, then they will automatically be made clickable when the display_answers function
                            displays candidate answers in a pandas DataFrame.
                            If references is None, the index of element in docs is used as reference.
        if not isinstance(docs, (np.ndarray, list)):
            raise ValueError("docs must be a list of strings")
        if references is not None and not isinstance(references, (np.ndarray, list)):
            raise ValueError("references must be a list of strings")
        if references is not None and len(references) != len(docs):
            raise ValueError("lengths of docs and references must be equal")
        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
        mb = master_bar(range(1))
        for i in mb:
            for idx, doc in enumerate(progress_bar(docs, parent=mb)):
                reference = "%s" % (idx) if references is None else references[idx]
                if breakup_docs:
                    small_docs = TU.paragraph_tokenize(
                        doc, join_sentences=True, lang="en"
                    refs = [reference] * len(small_docs)
                    for i, small_doc in enumerate(small_docs):
                        if len(small_doc.split()) < min_words:
                            continue
                        content = small_doc
                        reference = refs[i]
                        writer.add_document(
                            reference=reference, content=content, rawtext=content
                else:
                    if len(doc.split()) < min_words:
                        continue
                    content = doc
                    writer.add_document(
                        reference=reference, content=content, rawtext=content
                idx += 1
                if idx % commit_every == 0:
                    writer.commit()
                    # writer = ix.writer()
                    writer = ix.writer(
                        procs=procs, limitmb=limitmb, multisegment=multisegment
                mb.child.comment = f"indexing documents"
            writer.commit()
            # mb.write(f'Finished indexing documents')
        return
    @classmethod
    def index_from_folder(
        folder_path,
        index_dir,
        use_text_extraction=False,
        commit_every=1024,
        breakup_docs=True,
        min_words=20,
        encoding="utf-8",
        procs=1,
        limitmb=256,
        multisegment=False,
        verbose=1,
        index all plain text documents within a folder.
        The procs, limitmb, and especially multisegment arguments can be used to
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
        Args:
          folder_path(str): path to folder containing plain text documents (e.g., .txt files)
          index_dir(str): path to index directory (see initialize_index)
          use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                                     file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                                     If False, only plain text files will be indexed.
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          encoding(str): encoding to use when reading document files from disk
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          verbose(bool): verbosity
        if use_text_extraction:
            # TODO:  change this to use TextExtractor
            try:
                import textract
            except ImportError:
                raise Exception(
                    "use_text_extraction=True requires textract:   pip install textract"
                )
        if not os.path.isdir(folder_path):
            raise ValueError("folder_path is not a valid folder")
        if folder_path[-1] != os.sep:
            folder_path += os.sep
        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
        for idx, fpath in enumerate(TU.extract_filenames(folder_path)):
            reference = "%s" % (fpath.join(fpath.split(folder_path)[1:]))
            if TU.is_txt(fpath):
                with open(fpath, "r", encoding=encoding) as f:
                    doc = f.read()
            else:
                if use_text_extraction:
                    try:
                        doc = textract.process(fpath)
                        doc = doc.decode("utf-8", "ignore")
                    except:
                        if verbose:
                            warnings.warn("Could not extract text from %s" % (fpath))
                        continue
                else:
                    continue
            if breakup_docs:
                small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang="en")
                refs = [reference] * len(small_docs)
                for i, small_doc in enumerate(small_docs):
                    if len(small_doc.split()) < min_words:
                        continue
                    content = small_doc
                    reference = refs[i]
                    writer.add_document(
                        reference=reference, content=content, rawtext=content
            else:
                if len(doc.split()) < min_words:
                    continue
                content = doc
                writer.add_document(
                    reference=reference, content=content, rawtext=content
            idx += 1
            if idx % commit_every == 0:
                writer.commit()
                writer = ix.writer(
                    procs=procs, limitmb=limitmb, multisegment=multisegment
                if verbose:
                    print("%s docs indexed" % (idx))
        writer.commit()
        return
    def search(self, query, limit=10):
        search index for query
        Args:
          query(str): search query
          limit(int):  number of top search results to return
        Returns:
          list of dicts with keys: reference, rawtext
        ix = self._open_ix()
        with ix.searcher() as searcher:
            query_obj = QueryParser("content", ix.schema, group=qparser.OrGroup).parse(
                query
            results = searcher.search(query_obj, limit=limit)
            docs = []
            output = [dict(r) for r in results]
            return output

Ancestors

  • ExtractiveQABase
  • abc.ABC
  • TorchBase

Static methods

    def index_from_folder(folder_path, index_dir, use_text_extraction=False, commit_every=1024, breakup_docs=True, min_words=20, encoding='utf-8', procs=1, limitmb=256, multisegment=False, verbose=1)
    index all plain text documents within a folder.
    The procs, limitmb, and especially multisegment arguments can be used to
    speed up indexing, if it is too slow.  Please see the whoosh documentation
    for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
    Args:
      folder_path(str): path to folder containing plain text documents (e.g., .txt files)
      index_dir(str): path to index directory (see initialize_index)
      use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                                 file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                                 If False, only plain text files will be indexed.
      commit_every(int): commit after adding this many documents
      breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                          This can potentially improve the speed at which answers are returned by the ask method
                          when documents being searched are longer.
      min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                       Useful for pruning contexts that are unlikely to contain useful answers
      encoding(str): encoding to use when reading document files from disk
      procs(int): number of processors
      limitmb(int): memory limit in MB for each process
      multisegment(bool): new segments written instead of merging
      verbose(bool): verbosity
    Expand source code
    
    @classmethod
    def index_from_folder(
        folder_path,
        index_dir,
        use_text_extraction=False,
        commit_every=1024,
        breakup_docs=True,
        min_words=20,
        encoding="utf-8",
        procs=1,
        limitmb=256,
        multisegment=False,
        verbose=1,
        index all plain text documents within a folder.
        The procs, limitmb, and especially multisegment arguments can be used to
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
        Args:
          folder_path(str): path to folder containing plain text documents (e.g., .txt files)
          index_dir(str): path to index directory (see initialize_index)
          use_text_extraction(bool): If True, the  `textract` package will be used to index text from various
                                     file types including PDF, MS Word, and MS PowerPoint (in addition to plain text files).
                                     If False, only plain text files will be indexed.
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          encoding(str): encoding to use when reading document files from disk
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          verbose(bool): verbosity
        if use_text_extraction:
            # TODO:  change this to use TextExtractor
            try:
                import textract
            except ImportError:
                raise Exception(
                    "use_text_extraction=True requires textract:   pip install textract"
                )
        if not os.path.isdir(folder_path):
            raise ValueError("folder_path is not a valid folder")
        if folder_path[-1] != os.sep:
            folder_path += os.sep
        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
        for idx, fpath in enumerate(TU.extract_filenames(folder_path)):
            reference = "%s" % (fpath.join(fpath.split(folder_path)[1:]))
            if TU.is_txt(fpath):
                with open(fpath, "r", encoding=encoding) as f:
                    doc = f.read()
            else:
                if use_text_extraction:
                    try:
                        doc = textract.process(fpath)
                        doc = doc.decode("utf-8", "ignore")
                    except:
                        if verbose:
                            warnings.warn("Could not extract text from %s" % (fpath))
                        continue
                else:
                    continue
            if breakup_docs:
                small_docs = TU.paragraph_tokenize(doc, join_sentences=True, lang="en")
                refs = [reference] * len(small_docs)
                for i, small_doc in enumerate(small_docs):
                    if len(small_doc.split()) < min_words:
                        continue
                    content = small_doc
                    reference = refs[i]
                    writer.add_document(
                        reference=reference, content=content, rawtext=content
            else:
                if len(doc.split()) < min_words:
                    continue
                content = doc
                writer.add_document(
                    reference=reference, content=content, rawtext=content
            idx += 1
            if idx % commit_every == 0:
                writer.commit()
                writer = ix.writer(
                    procs=procs, limitmb=limitmb, multisegment=multisegment
                if verbose:
                    print("%s docs indexed" % (idx))
        writer.commit()
        return
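Example (a minimal sketch; paths are illustrative, and use_text_extraction=True additionally requires the textract package):

from ktrain.text import SimpleQA

SimpleQA.initialize_index("/tmp/myindex")
SimpleQA.index_from_folder(
    "/path/to/documents", "/tmp/myindex",
    use_text_extraction=True,      # also index PDFs, Word documents, etc.
    multisegment=True, procs=4,    # speed up indexing of large corpora
)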
    def index_from_list(docs, index_dir, commit_every=1024, breakup_docs=True, procs=1, limitmb=256, multisegment=False, min_words=20, references=None)
    index documents from list.
    The procs, limitmb, and especially multisegment arguments can be used to
    speed up indexing, if it is too slow.  Please see the whoosh documentation
    for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
    Args:
      docs(list): list of strings representing documents
      index_dir(str): path to index directory (see initialize_index)
      commit_every(int): commit after adding this many documents
      breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                          This can potentially improve the speed at which answers are returned by the ask method
                          when documents being searched are longer.
      procs(int): number of processors
      limitmb(int): memory limit in MB for each process
      multisegment(bool): new segments written instead of merging
      min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                       Useful for pruning contexts that are unlikely to contain useful answers
      references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                        Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                           Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                        Strings can also be hyperlinks in which case the label and URL should be separated by a single tab character:
                           Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]
                        These references will be returned in the output of the ask method.
                        If strings are  hyperlinks, then they will automatically be made clickable when the display_answers function
                        displays candidate answers in a pandas DataFrame.
                        If references is None, the index of element in docs is used as reference.
    Expand source code
    
    @classmethod
    def index_from_list(
        docs,
        index_dir,
        commit_every=1024,
        breakup_docs=True,
        procs=1,
        limitmb=256,
        multisegment=False,
        min_words=20,
        references=None,
        index documents from list.
        The procs, limitmb, and especially multisegment arguments can be used to
        speed up indexing, if it is too slow.  Please see the whoosh documentation
        for more information on these parameters:  https://whoosh.readthedocs.io/en/latest/batch.html
        Args:
          docs(list): list of strings representing documents
          index_dir(str): path to index directory (see initialize_index)
          commit_every(int): commit after adding this many documents
          breakup_docs(bool): break up documents into smaller paragraphs and treat those as the documents.
                              This can potentially improve the speed at which answers are returned by the ask method
                              when documents being searched are longer.
          procs(int): number of processors
          limitmb(int): memory limit in MB for each process
          multisegment(bool): new segments written instead of merging
          min_words(int):  minimum words for a document (or paragraph extracted from document when breakup_docs=True) to be included in index.
                           Useful for pruning contexts that are unlikely to contain useful answers
          references(list): List of strings containing a reference (e.g., file name) for each document in docs.
                            Each string is treated as a label for the document (e.g., file name, MD5 hash, etc.):
                               Example:  ['some_file.pdf', 'some_other_file.pdf', ...]
                            Strings can also be hyperlinks in which case the label and URL should be separated by a single tab character:
                               Example: ['ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf', ...]
                            These references will be returned in the output of the ask method.
                            If strings are  hyperlinks, then they will automatically be made clickable when the display_answers function
                            displays candidate answers in a pandas DataFrame.
                            If references is None, the index of element in docs is used as reference.
        if not isinstance(docs, (np.ndarray, list)):
            raise ValueError("docs must be a list of strings")
        if references is not None and not isinstance(references, (np.ndarray, list)):
            raise ValueError("references must be a list of strings")
        if references is not None and len(references) != len(docs):
            raise ValueError("lengths of docs and references must be equal")
        ix = index.open_dir(index_dir)
        writer = ix.writer(procs=procs, limitmb=limitmb, multisegment=multisegment)
        mb = master_bar(range(1))
        for i in mb:
            for idx, doc in enumerate(progress_bar(docs, parent=mb)):
                reference = "%s" % (idx) if references is None else references[idx]
                if breakup_docs:
                    small_docs = TU.paragraph_tokenize(
                        doc, join_sentences=True, lang="en"
                    refs = [reference] * len(small_docs)
                    for i, small_doc in enumerate(small_docs):
                        if len(small_doc.split()) < min_words:
                            continue
                        content = small_doc
                        reference = refs[i]
                        writer.add_document(
                            reference=reference, content=content, rawtext=content
                else:
                    if len(doc.split()) < min_words:
                        continue
                    content = doc
                    writer.add_document(
                        reference=reference, content=content, rawtext=content
                idx += 1
                if idx % commit_every == 0:
                    writer.commit()
                    # writer = ix.writer()
                    writer = ix.writer(
                        procs=procs, limitmb=limitmb, multisegment=multisegment
                mb.child.comment = f"indexing documents"
            writer.commit()
            # mb.write(f'Finished indexing documents')
        return
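Example (a minimal sketch illustrating the references parameter; the documents are short placeholders, so min_words is lowered to keep them from being pruned):

from ktrain.text import SimpleQA

docs = ["First document text ...", "Second document text ..."]
refs = ["report_a.pdf",
        "ktrain_article\thttps://arxiv.org/pdf/2004.10703v4.pdf"]  # tab separates label and URL
SimpleQA.initialize_index("/tmp/myindex")
SimpleQA.index_from_list(docs, "/tmp/myindex", references=refs, min_words=1)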
    def initialize_index(index_dir)
    Expand source code
    
    @classmethod
    def initialize_index(cls, index_dir):
        schema = Schema(
            reference=ID(stored=True), content=TEXT, rawtext=TEXT(stored=True)
        )
        if not os.path.exists(index_dir):
            os.makedirs(index_dir)
        else:
            raise ValueError(
                "There is already an existing directory or file with path %s"
                % (index_dir)
            )
        ix = index.create_in(index_dir, schema)
        return ix

    Methods

    def search(self, query, limit=10)
    search index for query
    Args:
      query(str): search query
      limit(int):  number of top search results to return
    Returns:
      list of dicts with keys: reference, rawtext
    Expand source code
    
    def search(self, query, limit=10):
        search index for query
        Args:
          query(str): search query
          limit(int):  number of top search results to return
        Returns:
          list of dicts with keys: reference, rawtext
        ix = self._open_ix()
        with ix.searcher() as searcher:
            query_obj = QueryParser("content", ix.schema, group=qparser.OrGroup).parse(
                query
            results = searcher.search(query_obj, limit=limit)
            docs = []
            output = [dict(r) for r in results]
            return output
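Example (a minimal sketch; assumes an index was previously created at the path shown):

from ktrain.text import SimpleQA

qa = SimpleQA("/tmp/myindex")
for hit in qa.search("natural language processing", limit=5):
    print(hit["reference"], hit["rawtext"][:80])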

    Inherited members

  • ExtractiveQABase:
      • predict_squad
      • quantize_model

class TextExtractor (use_tika=True)

Expand source code
class TextExtractor:
    def __init__(self, use_tika=True):
        if use_tika:
            try:
                from tika import parser
            except ImportError as e:
                raise ValueError(
                    "If use_tika=True, then TextExtractor requires tika: pip install tika"
                )
            except PermissionError as e:
                raise PermissionError(
                    f"There may already be a /tmp/tika.log file from another user - please delete it or change permissions: {e}"
                )
        if not use_tika and not TEXTRACT_INSTALLED:
            raise ValueError(
                "If use_tika=False, then TextExtractor requires textract: pip install textract"
            )
        self.use_tika = use_tika

    def extract(
        self, filename=None, text=None, return_format="document", lang=None, verbose=1
    ):
        """
        Extracts text from document given file path to document.
        filename(str): path to file.  Mutually-exclusive with text.
        text(str): string to tokenize.  Mutually-exclusive with filename.
                   The extract method can also simply accept a string and return lists of sentences or paragraphs.
        return_format(str): One of {'document', 'paragraphs', 'sentences'}
                          'document': returns text of document
                          'paragraphs': returns a list of paragraphs from document
                          'sentences': returns a list of sentences from document
        lang(str): language code. If None, lang will be detected from extracted text
        verbose(bool): verbosity
        """
        if filename is None and text is None:
            raise ValueError(
                "Either the filename parameter or the text parameter must be supplied"
            )
        if filename is not None and text is not None:
            raise ValueError("The filename and text parameters are mutually-exclusive.")
        if return_format not in ["document", "paragraphs", "sentences"]:
            raise ValueError(
                'return_format must be one of {"document", "paragraphs", "sentences"}'
            )
        if filename is not None:
            mtype = TU.get_mimetype(filename)
            try:
                if mtype and mtype.split("/")[0] == "text":
                    with open(filename, "r") as f:
                        text = f.read()
                        text = str.encode(text)
                else:
                    text = self._extract(filename)
            except Exception as e:
                if verbose:
                    print("ERROR on %s:\n%s" % (filename, e))
        try:
            text = text.decode(errors="ignore")
        except:
            pass
        if return_format == "sentences":
            return TU.sent_tokenize(text, lang=lang)
        elif return_format == "paragraphs":
            return TU.paragraph_tokenize(text, join_sentences=True, lang=lang)
        else:
            return text

    def _extract(self, filename):
        if self.use_tika:
            from tika import parser

            if JAVA_INSTALLED:
                parsed = parser.from_file(filename)
                text = parsed["content"]
            else:
                raise Exception("Please install Java for TIKA text extraction")
        else:
            text = textract.process(filename)
        return text.strip()

    Methods

    def extract(self, filename=None, text=None, return_format='document', lang=None, verbose=1)
    Extracts text from document given file path to document.
    filename(str): path to file,  Mutually-exclusive with text.
    text(str): string to tokenize.  Mutually-exclusive with filename.
               The extract method can also simply accept a string and return lists of sentences or paragraphs.
    return_format(str): One of {'document', 'paragraphs', 'sentences'}
                      'document': returns text of document
                      'paragraphs': returns a list of paragraphs from document
                      'sentences': returns a list of sentences from document
    lang(str): language code. If None, lang will be detected from extracted text
    verbose(bool): verbosity
    Expand source code
    
    def extract(
        self, filename=None, text=None, return_format="document", lang=None, verbose=1
        Extracts text from document given file path to document.
        filename(str): path to file,  Mutually-exclusive with text.
        text(str): string to tokenize.  Mutually-exclusive with filename.
                   The extract method can also simply accept a string and return lists of sentences or paragraphs.
        return_format(str): One of {'document', 'paragraphs', 'sentences'}
                          'document': returns text of document
                          'paragraphs': returns a list of paragraphs from document
                          'sentences': returns a list of sentences from document
        lang(str): language code. If None, lang will be detected from extracted text
        verbose(bool): verbosity
        if filename is None and text is None:
            raise ValueError(
                "Either the filename parameter or the text parameter must be supplied"
        if filename is not None and text is not None:
            raise ValueError("The filename and text parameters are mutually-exclusive.")
        if return_format not in ["document", "paragraphs", "sentences"]:
            raise ValueError(
                'return_format must be one of {"document", "paragraphs", "sentences"}'
        if filename is not None:
            mtype = TU.get_mimetype(filename)
            try:
                if mtype and mtype.split("/")[0] == "text":
                    with open(filename, "r") as f:
                        text = f.read()
                        text = str.encode(text)
                else:
                    text = self._extract(filename)
            except Exception as e:
                if verbose:
                    print("ERROR on %s:\n%s" % (filename, e))
        try:
            text = text.decode(errors="ignore")
        except:
            pass
        if return_format == "sentences":
            return TU.sent_tokenize(text, lang=lang)
        elif return_format == "paragraphs":
            return TU.paragraph_tokenize(text, join_sentences=True, lang=lang)
        else:
            return text
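    For illustration, a minimal usage sketch (the file path is hypothetical, and either tika or textract must be installed to match the use_tika setting):

    from ktrain.text import TextExtractor

    te = TextExtractor()  # pass use_tika=False to extract with textract instead of tika
    doc_text = te.extract(filename="/tmp/report.pdf")                        # full document text
    sentences = te.extract(filename="/tmp/report.pdf", return_format="sentences")
    paragraphs = te.extract(text=doc_text, return_format="paragraphs")       # also accepts a raw string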
    class get_topic_model (texts=None, n_topics=None, n_features=10000, min_df=5, max_df=0.5, stop_words='english', model_type='lda', max_iter=5, lda_max_iter=None, lda_mode='online', token_pattern=None, verbose=1, hyperparam_kwargs=None)

    Fits a topic model to documents in <texts>.

    Example

    tm = ktrain.text.get_topic_model(docs, n_topics=20, n_features=1000, min_df=2, max_df=0.95)

    texts : list of str
    list of texts
    n_topics : int
    number of topics. If None, n_topics = min(400, sqrt(# documents / 2))
    n_features : int
    maximum words to consider
    max_df : float
    words in more than max_df proportion of docs discarded
    stop_words : str or list
    either 'english' for built-in stop words or a list of stop words to ignore
    model_type(str): type of topic model to fit. One of {'lda', 'nmf'}. Default:'lda'
    max_iter : int
    maximum iterations. 5 is default if using lda_mode='online' or nmf. If lda_mode='batch', this should be increased (e.g., 1500).
    lda_max_iter : int
    alias for max_iter for backwards compatibility
    lda_mode : str
    one of {'online', 'batch'}. Ignored if model_type !='lda'

    token_pattern : str
    regex pattern to use to tokenize documents
    verbose : bool
    verbosity
    hyperparam_kwargs : dict
    hyperparameters for LDA/NMF. Keys in this dict can be any of the following:
        alpha: alpha for LDA. default: 5./n_topics
        beta: beta for LDA. default: 0.01
        nmf_alpha: alias for alpha for backwards compatibility
        l1_ratio: l1_ratio for NMF. default: 0
        ngram_range: whether to consider bigrams, trigrams. default: (1,1)
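    For illustration, a minimal sketch of fitting and inspecting a topic model (the docs list below is a stand-in for a real corpus):

    import ktrain

    # placeholder corpus; substitute your own list of document strings
    docs = [
        "the cat sat on the mat",
        "stock markets fell sharply today",
        "the dog chased the cat around the yard",
    ]
    tm = ktrain.text.get_topic_model(docs, n_topics=2, n_features=1000, min_df=1, max_df=0.95)
    tm.print_topics()  # one summary line per discovered topic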

    class TopicModel:
        def __init__(
            self,
            texts=None,
            n_topics=None,
            n_features=10000,
            min_df=5,
            max_df=0.5,
            stop_words="english",
            model_type="lda",
            max_iter=5,
            lda_max_iter=None,
            lda_mode="online",
            token_pattern=None,
            verbose=1,
            hyperparam_kwargs=None,
            Fits a topic model to documents in <texts>.
            Example:
                tm = ktrain.text.get_topic_model(docs, n_topics=20,
                                                n_features=1000, min_df=2, max_df=0.95)
            Args:
                texts (list of str): list of texts
                n_topics (int): number of topics.
                                If None, n_topics = min{400, sqrt[# documents/2]})
                n_features (int):  maximum words to consider
                max_df (float): words in more than max_df proportion of docs discarded
                stop_words (str or list): either 'english' for built-in stop words or
                                          a list of stop words to ignore
                model_type(str): type of topic model to fit. One of {'lda', 'nmf'}.  Default:'lda'
                max_iter (int): maximum iterations.  5 is default if using lda_mode='online' or nmf.
                                    If lda_mode='batch', this should be increased (e.g., 1500).
                lda_max_iter (int): alias for max_iter for backwards compatibility
                lda_mode (str):  one of {'online', 'batch'}. Ignored if model_type !='lda'
                token_pattern(str): regex pattern to use to tokenize documents.
                verbose(bool): verbosity
                hyperparam_kwargs(dict): hyperparameters for LDA/NMF
                                         Keys in this dict can be any of the following:
                                             alpha: alpha for LDA  default: 5./n_topics
                                             beta: beta for LDA.  default:0.01
                                             nmf_alpha: alias for alpha for backwards compatibility
                                             l1_ratio: l1_ratio for NMF. default: 0
                                             ngram_range:  whether to consider bigrams, trigrams. default: (1,1)
            self.verbose = verbose
            # estimate n_topics
            if n_topics is None:
                if texts is None:
                    raise ValueError("If n_topics is None, texts must be supplied")
                estimated = max(1, int(math.floor(math.sqrt(len(texts) / 2))))
                n_topics = min(400, estimated)
                if verbose:
                    print("n_topics automatically set to %s" % (n_topics))
            # train model
            if texts is not None:
                (model, vectorizer) = self.train(
                    texts,
                    model_type=model_type,
                    n_topics=n_topics,
                    n_features=n_features,
                    min_df=min_df,
                    max_df=max_df,
                    stop_words=stop_words,
                    max_iter=max_iter,
                    lda_max_iter=lda_max_iter,
                    lda_mode=lda_mode,
                    token_pattern=token_pattern,
                    hyperparam_kwargs=hyperparam_kwargs,
            else:
                vectorizer = None
                model = None
            # save model and vectorizer and hyperparameter settings
            self.vectorizer = vectorizer
            self.model = model
            self.n_topics = n_topics
            self.n_features = n_features
            if verbose:
                print("done.")
            # these variables are set by self.build():
            self.topic_dict = None
            self.doc_topics = None
            self.bool_array = None
            self.scorer = None  # set by self.train_scorer()
            self.recommender = None  # set by self.train_recommender()
            return
        def train(
            self,
            texts,
            model_type="lda",
            n_topics=None,
            n_features=10000,
            min_df=5,
            max_df=0.5,
            stop_words="english",
            max_iter=5,
            lda_max_iter=None,
            lda_mode="online",
            token_pattern=None,
            hyperparam_kwargs=None,
            Fits a topic model to documents in <texts>.
            Example:
                tm = ktrain.text.get_topic_model(docs, n_topics=20,
                                                n_features=1000, min_df=2, max_df=0.95)
            Args:
                texts (list of str): list of texts
                n_topics (int): number of topics.
                                If None, n_topics = min{400, sqrt[# documents/2]})
                n_features (int):  maximum words to consider
                max_df (float): words in more than max_df proportion of docs discarded
                stop_words (str or list): either 'english' for built-in stop words or
                                         a list of stop words to ignore
                max_iter (int): maximum iterations for 'lda'.  5 is default if using lda_mode='online'.
                lda_max_iter (int): alias for max_iter for backwards compatibility
                                    If lda_mode='batch', this should be increased (e.g., 1500).
                                    Ignored if model_type != 'lda'
                lda_mode (str):  one of {'online', 'batch'}. Ignored if model_type != 'lda'
                token_pattern(str): regex pattern to use to tokenize documents.
                                    If None, a default tokenizer will be used
                hyperparam_kwargs(dict): hyperparameters for LDA/NMF
                                         Keys in this dict can be any of the following:
                                             alpha: alpha for LDA  default: 5./n_topics
                                             beta: beta for LDA.  default:0.01
                                             nmf_alpha_W: alpha for NMF alpha_W (default is 0.0)
                                             nmf_alpha_H: alpha for NMF alpha_H (default is 'same')
                                             l1_ratio: l1_ratio for NMF. default: 0
                                             ngram_range:  whether to consider bigrams, trigrams. default: (1,1)
            Returns:
                tuple: (model, vectorizer)
            max_iter = lda_max_iter if lda_max_iter is not None else max_iter
            if hyperparam_kwargs is None:
                hyperparam_kwargs = {}
            alpha = hyperparam_kwargs.get("alpha", 5.0 / n_topics)
            nmf_alpha_W = hyperparam_kwargs.get("nmf_alpha_W", 0.0)
            nmf_alpha_H = hyperparam_kwargs.get("nmf_alpha_H", "same")
            beta = hyperparam_kwargs.get("beta", 0.01)
            l1_ratio = hyperparam_kwargs.get("l1_ratio", 0)
            ngram_range = hyperparam_kwargs.get("ngram_range", (1, 1))
            # adjust defaults based on language detected
            if texts is not None:
                lang = TU.detect_lang(texts)
                if lang != "en":
                    stopwords = None if stop_words == "english" else stop_words
                    token_pattern = (
                        r"(?u)\b\w+\b" if token_pattern is None else token_pattern
                if pp.is_nospace_lang(lang):
                    text_list = []
                    for t in texts:
                        text_list.append(" ".join(jieba.cut(t, HMM=False)))
                    texts = text_list
                if self.verbose:
                    print("lang: %s" % (lang))
            # preprocess texts
            if self.verbose:
                print("preprocessing texts...")
            if token_pattern is None:
                token_pattern = TU.DEFAULT_TOKEN_PATTERN
            # if token_pattern is None: token_pattern = r'(?u)\b\w\w+\b'
            vectorizer = CountVectorizer(
                max_df=max_df,
                min_df=min_df,
                max_features=n_features,
                stop_words=stop_words,
                token_pattern=token_pattern,
                ngram_range=ngram_range,
            x_train = vectorizer.fit_transform(texts)
            # fit model
            if self.verbose:
                print("fitting model...")
            if model_type == "lda":
                model = LatentDirichletAllocation(
                    n_components=n_topics,
                    max_iter=max_iter,
                    learning_method=lda_mode,
                    learning_offset=50.0,
                    doc_topic_prior=alpha,
                    topic_word_prior=beta,
                    verbose=self.verbose,
                    random_state=0,
            elif model_type == "nmf":
                model = NMF(
                    n_components=n_topics,
                    max_iter=max_iter,
                    verbose=self.verbose,
                    alpha_W=nmf_alpha_W,
                    alpha_H=nmf_alpha_H,
                    l1_ratio=l1_ratio,
                    random_state=0,
            else:
                raise ValueError("unknown model type:", str(model_type))
            model.fit(x_train)
            # save model and vectorizer and hyperparameter settings
            return (model, vectorizer)
        @property
        def topics(self):
            convenience method/property
            return self.get_topics()
        def get_document_topic_distribution(self):
            Gets the document-topic distribution.
            Each row is a document and each column is a topic
            The output of this method is equivalent to invoking get_doctopics with no arguments.
            self._check_build()
            return self.doc_topics
        def get_sorted_docs(self, topic_id):
            Returns all docs sorted by relevance to <topic_id>.
            Unlike get_docs, this ranks documents by the supplied topic_id rather
            than the topic_id to which each document is most relevant.
            docs = self.get_docs()
            d = {}
            for doc in docs:
                d[doc["doc_id"]] = doc
            m = self.get_document_topic_distribution()
            doc_ids = (-m[:, topic_id]).argsort()
            return [d[doc_id] for doc_id in doc_ids]
        def get_word_weights(self, topic_id, n_words=100):
            Returns a list of tuples of the form: (word, weight) for given topic_id.
            The weight can be interpreted as the number of times word was assigned to topic with given topic_id.
            REFERENCE: https://stackoverflow.com/a/48890889/13550699
            Args:
                topic_id(int): topic ID
                n_words(int): number of top words
            self._check_model()
            if topic_id + 1 > len(self.model.components_):
                raise ValueError(
                    "topic_id must be less than %s" % (len(self.model.components_))
            feature_names = self.vectorizer.get_feature_names_out()
            word_probs = self.model.components_[topic_id]
            word_ids = [i for i in word_probs.argsort()[: -n_words - 1 : -1]]
            words = [feature_names[i] for i in word_ids]
            probs = [word_probs[i] for i in word_ids]
            return list(zip(words, probs))
        def get_topics(self, n_words=10, as_string=True, show_counts=False):
            Returns a list of discovered topics
            Args:
                n_words(int): number of words to use in topic summary
                as_string(bool): If True, each summary is a space-delimited string instead of list of words
                show_counts(bool): If True, returns list of tuples of form (id, topic summary, count).
                                   Otherwise, a list of topic summaries.
            Returns:
              List of topic summaries if  show_count is False
              Dictionary where key is topic ID and value is a tuple of form (topic summary, count) if show_count is True
            self._check_model()
            feature_names = self.vectorizer.get_feature_names_out()
            topic_summaries = []
            for topic_idx, topic in enumerate(self.model.components_):
                summary = [feature_names[i] for i in topic.argsort()[: -n_words - 1 : -1]]
                if as_string:
                    summary = " ".join(summary)
                topic_summaries.append(summary)
            if show_counts:
                self._check_build()
                topic_counts = sorted(
                    [(k, topic_summaries[k], len(v)) for k, v in self.topic_dict.items()],
                    key=lambda kv: kv[-1],
                    reverse=True,
                return dict((t[0], t[1:]) for t in topic_counts)
            return topic_summaries
        def print_topics(self, n_words=10, show_counts=False):
            print topics
            n_words(int): number of words to describe each topic
            show_counts(bool): If True, print topics with document counts, where
                               the count is the number of documents with that topic as primary.
            topics = self.get_topics(n_words=n_words, as_string=True)
            if show_counts:
                self._check_build()
                topic_counts = sorted(
                    [(k, topics[k], len(v)) for k, v in self.topic_dict.items()],
                    key=lambda kv: kv[-1],
                    reverse=True,
                for idx, topic, count in topic_counts:
                    print("topic:%s | count:%s | %s" % (idx, count, topic))
            else:
                for i, t in enumerate(topics):
                    print("topic %s | %s" % (i, t))
            return
        def build(self, texts, threshold=None):
            Builds the document-topic distribution showing the topic probability distribution
            for each document in <texts> with respect to the learned topic space.
            Args:
                texts (list of str): list of text documents
                threshold (float): If not None, documents whose highest topic probability
                                   is less than threshold are filtered out.
            if threshold is not None:
                doc_topics, bool_array = self.predict(texts, threshold=threshold)
            else:
                doc_topics = self.predict(texts)
                bool_array = np.array([True] * len(texts))
            self.doc_topics = doc_topics
            self.bool_array = bool_array
            texts = [text for i, text in enumerate(texts) if bool_array[i]]
            self.topic_dict = self._rank_documents(texts, doc_topics=doc_topics)
            return
        def filter(self, obj):
            The build method may prune documents based on threshold.
            This method prunes other lists based on how build pruned documents.
            This is useful to filter lists containing metadata associated with documents
            for use with visualize_documents.
            Args:
                obj(list|np.ndarray|pandas.DataFrame):a list, numpy array, or DataFrame of data
            Returns:
                filtered obj
            length = (
                obj.shape[0] if isinstance(obj, (pd.DataFrame, np.ndarray)) else len(obj)
            if length != self.bool_array.shape[0]:
                raise ValueError(
                    "Length of obj is not consistent with the number of documents "
                    + "supplied to get_topic_model"
            obj = np.array(obj) if isinstance(obj, list) else obj
            return obj[self.bool_array]
        def get_docs(self, topic_ids=[], doc_ids=[], rank=False):
            Returns document entries for supplied topic_ids.
            Documents returned are those whose primary topic is topic with given topic_id
            Args:
                topic_ids(list of ints): list of topic IDs where each id is in the range
                                         of range(self.n_topics).
                doc_ids (list of ints): list of document IDs where each id is an index
                                        into self.doctopics
                rank(bool): If True, the list is sorted first by topic_id (ascending)
                            and then by topic probability (descending).
                            Otherwise, the list is sorted by doc_id (i.e., the order
                            of texts supplied to self.build, which is the order of self.doc_topics).
            Returns:
                list of dicts:  list of dicts with keys:
                                'text': text of document
                                'doc_id': ID of document
                                'topic_proba': topic probability (or score)
                                'topic_id': ID of topic
            self._check_build()
            if not topic_ids:
                topic_ids = list(range(self.n_topics))
            result_texts = []
            for topic_id in topic_ids:
                if topic_id not in self.topic_dict:
                    continue
                texts = [
                        "text": tup[0],
                        "doc_id": tup[1],
                        "topic_proba": tup[2],
                        "topic_id": topic_id,
                    for tup in self.topic_dict[topic_id]
                    if not doc_ids or tup[1] in doc_ids
                result_texts.extend(texts)
            if not rank:
                result_texts = sorted(result_texts, key=lambda x: x["doc_id"])
            return result_texts
        def get_doctopics(self, topic_ids=[], doc_ids=[]):
            Returns a topic probability distribution for documents
            with primary topic that is one of <topic_ids> and with doc_id in <doc_ids>.
            If no topic_ids or doc_ids are provided, then topic distributions for all documents
            are returned (which is equivalent to the output of get_document_topic_distribution).
            Args:
                topic_ids(list of ints): list of topic IDs where each id is in the range
                                         of range(self.n_topics).
                doc_ids (list of ints): list of document IDs where each id is an index
                                        into self.doctopics
            Returns:
                np.ndarray: Each row is the topic probability distribution of a document.
                            Array is sorted in the order returned by self.get_docs.
            docs = self.get_docs(topic_ids=topic_ids, doc_ids=doc_ids)
            return np.array([self.doc_topics[idx] for idx in [x["doc_id"] for x in docs]])
        def get_texts(self, topic_ids=[]):
            Returns texts for documents
            with primary topic that is one of <topic_ids>
            Args:
                topic_ids(list of ints): list of topic IDs
            Returns:
                list of str
            if not topic_ids:
                topic_ids = list(range(self.n_topics))
            docs = self.get_docs(topic_ids)
            return [x[0] for x in docs]
        def predict(self, texts, threshold=None, harden=False):
            Args:
                texts (list of str): list of texts
                threshold (float): If not None, documents with maximum topic scores
                                    less than <threshold> are filtered out
                harden(bool): If True, each document is assigned to a single topic for which
                              it has the highest score
            Returns:
                if threshold is None:
                    np.ndarray: topic distribution for each text document
                else:
                    (np.ndarray, np.ndarray): topic distribution and boolean array
            self._check_model()
            transformed_texts = self.vectorizer.transform(texts)
            X_topics = self.model.transform(transformed_texts)
            # if self.model_type == 'nmf':
            # scores = np.matrix(X_topics)
            # scores_normalized= scores/scores.sum(axis=1)
            # X_topics = scores_normalized
            _idx = np.array([True] * len(texts))
            if threshold is not None:
                _idx = (
                    np.amax(X_topics, axis=1) > threshold
                )  # idx of doc that above the threshold
                _idx = np.array(_idx)
                X_topics = X_topics[_idx]
            if harden:
                X_topics = self._harden_topics(X_topics)
            if threshold is not None:
                return (X_topics, _idx)
            else:
                return X_topics
        def visualize_documents(
            self,
            texts=None,
            doc_topics=None,
            width=700,
            height=700,
            point_size=5,
            title="Document Visualization",
            extra_info={},
            colors=None,
            filepath=None,
            Generates a visualization of a set of documents based on model.
            If <texts> is supplied, raw documents will be first transformed into document-topic
            matrix.  If <doc_topics> is supplied, then this will be used for visualization instead.
            Args:
                texts(list of str): list of document texts.  Mutually-exclusive with <doc_topics>
                doc_topics(ndarray): pre-computed topic distribution for each document in texts.
                                     Mutually-exclusive with <texts>.
                width(int): width of image
                height(int): height of image
                point_size(int): size of circles in plot
                title(str):  title of visualization
                extra_info(dict of lists): A user-supplied information for each datapoint (attributes of the datapoint).
                                           The keys are field names.  The values are lists - each of which must
                                           be the same number of elements as <texts> or <doc_topics>. These fields are displayed
                                           when hovering over datapoints in the visualization.
                colors(list of str):  list of Hex color codes for each datapoint.
                                      Length of list must match either len(texts) or doc_topics.shape[0]
                filepath(str):             Optional filepath to save the interactive visualization
            # error-checking
            if texts is not None:
                length = len(texts)
            else:
                length = doc_topics.shape[0]
            if colors is not None and len(colors) != length:
                raise ValueError(
                    "length of colors is not consistent with length of texts or doctopics"
            if texts is not None and doc_topics is not None:
                raise ValueError("texts is mutually-exclusive with doc_topics")
            if texts is None and doc_topics is None:
                raise ValueError("One of texts or doc_topics is required.")
            if extra_info:
                invalid_keys = ["x", "y", "topic", "fill_color"]
                for k in extra_info.keys():
                    if k in invalid_keys:
                        raise ValueError('cannot use "%s" as key in extra_info' % (k))
                    lst = extra_info[k]
                    if len(lst) != length:
                        raise ValueError("texts and extra_info lists must be same size")
            # check for bokeh
                import bokeh.plotting as bp
                from bokeh.io import output_notebook
                from bokeh.models import HoverTool
                from bokeh.plotting import save
            except:
                warnings.warn(
                    "visualize_documents method requires bokeh package: pip install bokeh"
                return
            # prepare data
            if doc_topics is not None:
                X_topics = doc_topics
            else:
                if self.verbose:
                    print("transforming texts...", end="")
                X_topics = self.predict(texts, harden=False)
                if self.verbose:
                    print("done.")
            # reduce to 2-D
            if self.verbose:
                print("reducing to 2 dimensions...", end="")
            tsne_model = TSNE(
                n_components=2, verbose=self.verbose, random_state=0, angle=0.99, init="pca"
            tsne_lda = tsne_model.fit_transform(X_topics)
            print("done.")
            # get random colormap
            colormap = U.get_random_colors(self.n_topics)
            # generate inline visualization in Jupyter notebook
            lda_keys = self._harden_topics(X_topics)
            if colors is None:
                colors = colormap[lda_keys]
            topic_summaries = self.get_topics(n_words=5)
            os.environ["BOKEH_RESOURCES"] = "inline"
            output_notebook()
            dct = {
                "x": tsne_lda[:, 0],
                "y": tsne_lda[:, 1],
                "topic": [topic_summaries[tid] for tid in lda_keys],
                "fill_color": colors,
            tool_tups = [("index", "$index"), ("(x,y)", "($x,$y)"), ("topic", "@topic")]
            for k in extra_info.keys():
                dct[k] = extra_info[k]
                tool_tups.append((k, "@" + k))
            source = bp.ColumnDataSource(data=dct)
            hover = HoverTool(tooltips=tool_tups)
            p = bp.figure(
                width=width,
                height=height,
                tools=[hover, "save", "pan", "wheel_zoom", "box_zoom", "reset"],
                # tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave",
                title=title,
            # plot_lda = bp.figure(plot_width=1400, plot_height=1100,
            # title=title,
            # tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave",
            # x_axis_type=None, y_axis_type=None, min_border=1)
            p.circle("x", "y", size=point_size, source=source, fill_color="fill_color")
            bp.show(p)
            if filepath is not None:
                bp.output_file(filepath)
                bp.save(p)
            return
        def train_recommender(self, n_neighbors=20, metric="minkowski", p=2):
            Trains a recommender that, given a single document, will return
            documents in the corpus that are semantically similar to it.
            Args:
                n_neighbors (int):
            Returns:
            from sklearn.neighbors import NearestNeighbors
            rec = NearestNeighbors(n_neighbors=n_neighbors, metric=metric, p=p)
            probs = self.get_doctopics()
            rec.fit(probs)
            self.recommender = rec
            return
        def recommend(self, text=None, doc_topic=None, n=5, n_neighbors=100):
            Given an example document, recommends documents similar to it
            from the set of documents supplied to build().
            Args:
                texts(list of str): list of document texts.  Mutually-exclusive with <doc_topics>
                doc_topics(ndarray): pre-computed topic distribution for each document in texts.
                                     Mutually-exclusive with <texts>.
                n (int): number of recommendations to return
            Returns:
                list of tuples: each tuple is of the form:
                                (text, doc_id, topic_probability, topic_id)
            # error-checks
            if text is not None and doc_topic is not None:
                raise ValueError("text is mutually-exclusive with doc_topic")
            if text is None and doc_topic is None:
                raise ValueError("One of text or doc_topic is required.")
            if text is not None and type(text) not in [str]:
                raise ValueError("text must be a str ")
            if doc_topic is not None and type(doc_topic) not in [np.ndarray]:
                raise ValueError("doc_topic must be a np.ndarray")
            if n > n_neighbors:
                n_neighbors = n
            x_test = [doc_topic]
            if text:
                x_test = self.predict([text])
            docs = self.get_docs()
            indices = self.recommender.kneighbors(
                x_test, return_distance=False, n_neighbors=n_neighbors
            results = [doc for i, doc in enumerate(docs) if i in indices]
            return results[:n]
        def train_scorer(self, topic_ids=[], doc_ids=[], n_neighbors=20):
            Trains a scorer that can score documents based on similarity to a
            seed set of documents represented by topic_ids and doc_ids.
            NOTE: The score method currently employs the use of LocalOutLierFactor, which
            means you should not try to score documents that were used in training. Only
            new, unseen documents should be scored for similarity.
            REFERENCE:
            https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor
            Args:
                topic_ids(list of ints): list of topic IDs where each id is in the range
                                         of range(self.n_topics).  Documents associated
                                         with these topic_ids will be used as seed set.
                doc_ids (list of ints): list of document IDs where each id is an index
                                        into self.doctopics.  Documents associated
                                        with these doc_ids will be used as seed set.
            Returns:
            from sklearn.neighbors import LocalOutlierFactor
            clf = LocalOutlierFactor(
                n_neighbors=n_neighbors, novelty=True, contamination=0.1
            probs = self.get_doctopics(topic_ids=topic_ids, doc_ids=doc_ids)
            clf.fit(probs)
            self.scorer = clf
            return
        def score(self, texts=None, doc_topics=None):
            Given a new set of documents (supplied as texts or doc_topics), the score method
            uses a One-Class classifier to score documents based on similarity to a
            seed set of documents (where seed set is computed by train_scorer() method).
            Higher scores indicate a higher degree of similarity.
            Positive values represent a binary decision of similar.
            Negative values represent a binary decision of dissimilar.
            In practice, negative scores closer to zero will also be similar as One-Class
            classifiers are more strict than traditional binary classifiers.
            Documents with negative scores closer to zero are good candidates for
            inclusion in a training set for binary classification (e.g., active labeling).
            NOTE: The score method currently employs the use of LocalOutLierFactor, which
            means you should not try to score documents that were used in training. Only
            new, unseen documents should be scored for similarity.
            Args:
                texts(list of str): list of document texts.  Mutually-exclusive with <doc_topics>
                doc_topics(ndarray): pre-computed topic distribution for each document in texts.
                                     Mutually-exclusive with <texts>.
            Returns:
                list of floats:  larger values indicate higher degree of similarity
                                 positive values indicate a binary decision of similar
                                 negative values indicate binary decision of dissimilar
                                 In practice, negative scores closer to zero will also
                                 be similar as One-class classifiers are more strict
                                 than traditional binary classifiers.
            # error-checks
            if texts is not None and doc_topics is not None:
                raise ValueError("texts is mutually-exclusive with doc_topics")
            if texts is None and doc_topics is None:
                raise ValueError("One of texts or doc_topics is required.")
            if texts is not None and type(texts) not in [list, np.ndarray]:
                raise ValueError("texts must be either a list or numpy ndarray")
            if doc_topics is not None and type(doc_topics) not in [np.ndarray]:
                raise ValueError("doc_topics must be a np.ndarray")
            x_test = doc_topics
            if texts:
                x_test = self.predict(texts)
            return self.scorer.decision_function(x_test)
        def search(self, query, topic_ids=[], doc_ids=[], case_sensitive=False):
            search documents for query string.
            Args:
                query(str):  the word or phrase to search
                topic_ids(list of ints): list of topic IDs where each id is in the range
                                         of range(self.n_topics).
                doc_ids (list of ints): list of document IDs where each id is an index
                                        into self.doctopics
                case_sensitive(bool):  If True, case sensitive search
            # setup pattern
            if not case_sensitive:
                query = query.lower()
            pattern = re.compile(r"\b%s\b" % query)
            # retrieve docs
            docs = self.get_docs(topic_ids=topic_ids, doc_ids=doc_ids)
            # search
            mb = master_bar(range(1))
            results = []
            for i in mb:
                for doc in progress_bar(docs, parent=mb):
                    text = doc["text"]
                    if not case_sensitive:
                        text = text.lower()
                    matches = pattern.findall(text)
                    if matches:
                        results.append(doc)
                if self.verbose:
                    mb.write("done.")
            return results
        def _rank_documents(self, texts, doc_topics=None):
            Rank documents by topic score.
            If topic_index is supplied, rank documents based on relevance to supplied topic.
            Otherwise, rank all texts by their highest topic score (for any topic).
            Args:
                texts(list of str): list of document texts.
                doc_topics(ndarray): pre-computed topic distribution for each document
                                     If None, re-computed from texts.
            Returns:
                dict of lists: each element in list is a tuple of (doc_index, topic_index, score)
                ... where doc_index is an index into either texts
            if doc_topics is not None:
                X_topics = doc_topics
            else:
                if self.verbose:
                    print("transforming texts to topic space...")
                X_topics = self.predict(texts)
            topics = np.argmax(X_topics, axis=1)
            scores = np.amax(X_topics, axis=1)
            doc_ids = np.array([i for i, x in enumerate(texts)])
            result = list(zip(texts, doc_ids, topics, scores))
            if self.verbose:
                print("done.")
            result = sorted(result, key=lambda x: x[-1], reverse=True)
            result_dict = {}
            for r in result:
                text = r[0]
                doc_id = r[1]
                topic_id = r[2]
                score = r[3]
                lst = result_dict.get(topic_id, [])
                lst.append((text, doc_id, score))
                result_dict[topic_id] = lst
            return result_dict
        def _harden_topics(self, X_topics):
            Transforms soft-clustering to hard-clustering
            max_topics = []
            for i in range(X_topics.shape[0]):
                max_topics.append(X_topics[i].argmax())
            X_topics = np.array(max_topics)
            return X_topics
        def _check_build(self):
            self._check_model()
            if self.topic_dict is None:
                raise Exception("Must call build() method.")
        def _check_scorer(self):
            if self.scorer is None:
                raise Exception("Must call train_scorer()")
        def _check_recommender(self):
            if self.recommender is None:
                raise Exception("Must call train_recommender()")
        def _check_model(self):
            if self.model is None or self.vectorizer is None:
                raise Exception("Must call train()")
        def save(self, fname):
            save TopicModel object
            with open(fname + ".tm_vect", "wb") as f:
                pickle.dump(self.vectorizer, f)
            with open(fname + ".tm_model", "wb") as f:
                pickle.dump(self.model, f)
            params = {
                "n_topics": self.n_topics,
                "n_features": self.n_features,
                "verbose": self.verbose,
            with open(fname + ".tm_params", "wb") as f:
                pickle.dump(params, f)
            return

    Instance variables

    var topics

    convenience method/property


    Methods

    def build(self, texts, threshold=None)

    Builds the document-topic distribution showing the topic probability distribution for each document in <texts> with respect to the learned topic space.

    texts : list of str
    list of text documents
    threshold : float
    If not None, documents whose highest topic probability is less than threshold are filtered out.
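    Continuing the sketch above, building the document-topic distribution might look like this (the threshold value is arbitrary):

    tm.build(docs, threshold=0.25)        # documents whose top topic probability is below 0.25 are dropped
    tm.print_topics(show_counts=True)     # topic summaries with per-topic document counts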
    def filter(self, obj)

    The build method may prune documents based on threshold. This method prunes other lists based on how build pruned documents. This is useful to filter lists containing metadata associated with documents for use with visualize_documents.

    obj (list|np.ndarray|pandas.DataFrame): a list, numpy array, or DataFrame of data

    Returns

    filtered obj

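    A short sketch of keeping per-document metadata aligned after build() has pruned documents (the metadata list here is made up):

    # hypothetical metadata, same length and order as the docs passed to build()
    labels = ["doc-%d" % i for i in range(len(docs))]
    labels_kept = tm.filter(labels)       # drops the same rows that build() dropped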
    def get_docs(self, topic_ids=[], doc_ids=[], rank=False)

    Returns document entries for supplied topic_ids. Documents returned are those whose primary topic is the topic with the given topic_id.

    topic_ids(list of ints): list of topic IDs where each id is in the range
    of range(self.n_topics).
    doc_ids : list of ints
    list of document IDs where each id is an index into self.doctopics

    rank(bool): If True, the list is sorted first by topic_id (ascending) and then by topic probability (descending). Otherwise, the list is sorted by doc_id (i.e., the order of texts supplied to self.build, which is the order of self.doc_topics).

    Returns

    list of dicts
    list of dicts with keys:
    'text': text of document
    'doc_id': ID of document
    'topic_proba': topic probability (or score)
    'topic_id': ID of topic
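    For example, to pull the documents assigned to a single topic (topic ID 0 is arbitrary and assumes build() was called, as in the sketch above):

    for doc in tm.get_docs(topic_ids=[0], rank=True)[:3]:
        print(doc["topic_id"], round(doc["topic_proba"], 3), doc["text"][:60])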
    def get_doctopics(self, topic_ids=[], doc_ids=[])

    Returns a topic probability distribution for documents with primary topic that is one of <topic_ids> and with doc_id in <doc_ids>.

    If no topic_ids or doc_ids are provided, then topic distributions for all documents are returned (which is equivalent to the output of get_document_topic_distribution).

    topic_ids(list of ints): list of topic IDs where each id is in the range
    of range(self.n_topics).
    doc_ids : list of ints
    list of document IDs where each id is an index into self.doctopics

    Returns

    np.ndarray
    Each row is the topic probability distribution of a document. Array is sorted in the order returned by self.get_docs.
    def get_document_topic_distribution(self)

    Gets the document-topic distribution. Each row is a document and each column is a topic. The output of this method is equivalent to invoking get_doctopics with no arguments.

    def get_sorted_docs(self, topic_id)

    Returns all docs sorted by relevance to <topic_id>. Unlike get_docs, this ranks documents by the supplied topic_id rather than the topic_id to which each document is most relevant.

    def get_texts(self, topic_ids=[])

    Returns texts for documents with primary topic that is one of <topic_ids>

    topic_ids(list of ints): list of topic IDs

    Returns

    list of str

    def get_topics(self, n_words=10, as_string=True, show_counts=False)

    Returns a list of discovered topics

    n_words(int): number of words to use in topic summary
    as_string(bool): If True, each summary is a space-delimited string instead of a list of words
    show_counts(bool): If True, returns a list of tuples of the form (id, topic summary, count); otherwise, a list of topic summaries.

    Returns

    A list of topic summaries if show_counts is False; a dictionary where the key is the topic ID and the value is a tuple of the form (topic summary, count) if show_counts is True.

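    For example, continuing the sketch above:

    summaries = tm.get_topics(n_words=8)                   # list of space-delimited topic summaries
    counts = tm.get_topics(n_words=8, show_counts=True)    # dict of (summary, count) entries; requires build()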
    def get_word_weights(self, topic_id, n_words=100)

    Returns a list of tuples of the form: (word, weight) for given topic_id. The weight can be interpreted as the number of times word was assigned to topic with given topic_id. REFERENCE: https://stackoverflow.com/a/48890889/13550699

    topic_id(int): topic ID
    n_words(int): number of top words

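    For example, to inspect the top-weighted words of one topic (topic 0 is arbitrary):

    for word, weight in tm.get_word_weights(topic_id=0, n_words=10):
        print("%-20s %.2f" % (word, weight))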
    def predict(self, texts, threshold=None, harden=False)
    texts : list of str
    list of texts
    threshold : float
    If not None, documents with maximum topic scores less than <threshold> are filtered out

    harden(bool): If True, each document is assigned to a single topic for which it has the highest score

    Returns

    if threshold is None:
    np.ndarray
    topic distribution for each text document

    else: (np.ndarray, np.ndarray): topic distribution and boolean array

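    A quick sketch of scoring unseen documents against the learned topic space (the example text is made up):

    new_docs = ["the cat chased a ball of yarn"]        # placeholder unseen documents
    dist = tm.predict(new_docs)                         # ndarray of shape (len(new_docs), n_topics)
    dist, kept = tm.predict(new_docs, threshold=0.25)   # also returns a boolean keep-mask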
    def print_topics(self, n_words=10, show_counts=False)

    Print topics.
    n_words(int): number of words to describe each topic
    show_counts(bool): If True, print topics with document counts, where the count is the number of documents with that topic as primary.

    def recommend(self, text=None, doc_topic=None, n=5, n_neighbors=100)

    Given an example document, recommends documents similar to it from the set of documents supplied to build().

    text(str): the text of an example document.  Mutually-exclusive with <doc_topic>.
    doc_topic(ndarray): pre-computed topic distribution for the example document.  Mutually-exclusive with <text>.
    n : int
    number of recommendations to return

    Returns

    list of tuples
    each tuple is of the form: (text, doc_id, topic_probability, topic_id)
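    A hedged sketch of the recommender workflow (assumes build() has already been called; the query text is made up):

    tm.train_recommender()                               # fit a nearest-neighbors index over the doc-topic space
    for doc in tm.recommend(text="stock markets", n=3):
        print(doc)                                       # each entry comes from get_docs()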
    def save(self, fname)

    save TopicModel object

    Expand source code
    def save(self, fname):
        save TopicModel object
        with open(fname + ".tm_vect", "wb") as f:
            pickle.dump(self.vectorizer, f)
        with open(fname + ".tm_model", "wb") as f:
            pickle.dump(self.model, f)
        params = {
            "n_topics": self.n_topics,
            "n_features": self.n_features,
            "verbose": self.verbose,
        with open(fname + ".tm_params", "wb") as f:
            pickle.dump(params, f)
        return
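
    A sketch of persisting and reloading a topic model (the module-level load_topic_model function reconstructs the object from these files; documents added via build are not pickled, so call build again after loading):

        tm.save("/tmp/my_topic_model")                        # writes .tm_vect, .tm_model, and .tm_params files
        tm2 = ktrain.text.load_topic_model("/tmp/my_topic_model")
        tm2.build(texts, threshold=0.25)                      # re-attach documents before using get_docs, recommend, etc.
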
    def score(self, texts=None, doc_topics=None)

    Given a new set of documents (supplied as texts or doc_topics), the score method uses a One-Class classifier to score documents based on similarity to a seed set of documents (where seed set is computed by train_scorer() method).

    Higher scores indicate a higher degree of similarity. Positive values represent a binary decision of similar. Negative values represent a binary decision of dissimilar. In practice, negative scores closer to zero will also be similar, as One-Class classifiers are more strict than traditional binary classifiers. Documents with negative scores closer to zero are good candidates for inclusion in a training set for binary classification (e.g., active labeling).

    NOTE: The score method currently employs the use of LocalOutlierFactor, which means you should not try to score documents that were used in training. Only new, unseen documents should be scored for similarity.

    texts(list of str): list of document texts. Mutually-exclusive with doc_topics.
    doc_topics(ndarray): pre-computed topic distribution for each document in texts. Mutually-exclusive with texts.

    Returns

    list of floats
    Larger values indicate a higher degree of similarity. Positive values indicate a binary decision of similar; negative values indicate a binary decision of dissimilar. In practice, negative scores closer to zero will also be similar, as One-Class classifiers are more strict than traditional binary classifiers.
    Expand source code
    def score(self, texts=None, doc_topics=None):
        Given a new set of documents (supplied as texts or doc_topics), the score method
        uses a One-Class classifier to score documents based on similarity to a
        seed set of documents (where seed set is computed by train_scorer() method).
        Higher scores indicate a higher degree of similarity.
        Positive values represent a binary decision of similar.
        Negative values represent a binary decision of dissimilar.
        In practice, negative scores closer to zero will also be similar as One-Class
        classifiers are more strict than traditional binary classifiers.
        Documents with negative scores closer to zero are good candidates for
        inclusion in a training set for binary classification (e.g., active labeling).
        NOTE: The score method currently employs the use of LocalOutlierFactor, which
        means you should not try to score documents that were used in training. Only
        new, unseen documents should be scored for similarity.
        Args:
            texts(list of str): list of document texts.  Mutually-exclusive with <doc_topics>
            doc_topics(ndarray): pre-computed topic distribution for each document in texts.
                                 Mutually-exclusive with <texts>.
        Returns:
            list of floats:  larger values indicate higher degree of similarity
                             positive values indicate a binary decision of similar
                             negative values indicate binary decision of dissimilar
                             In practice, negative scores closer to zero will also
                             be similar as One-class classifiers are more strict
                             than traditional binary classifiers.
        # error-checks
        if texts is not None and doc_topics is not None:
            raise ValueError("texts is mutually-exclusive with doc_topics")
        if texts is None and doc_topics is None:
            raise ValueError("One of texts or doc_topics is required.")
        if texts is not None and type(texts) not in [list, np.ndarray]:
            raise ValueError("texts must be either a list or numpy ndarray")
        if doc_topics is not None and type(doc_topics) not in [np.ndarray]:
            raise ValueError("doc_topics must be a np.ndarray")
        x_test = doc_topics
        if texts:
            x_test = self.predict(texts)
        return self.scorer.decision_function(x_test)
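
    A sketch of one-class scoring against a seed set (the topic IDs and variable names below are placeholders):

        tm.train_scorer(topic_ids=[3, 7])                 # documents in topics 3 and 7 form the seed set
        scores = tm.score(texts=new_unseen_texts)         # new_unseen_texts must exclude training documents
        similar = [t for t, s in zip(new_unseen_texts, scores) if s > 0]
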
    def search(self, query, topic_ids=[], doc_ids=[], case_sensitive=False)

    search documents for query string.

    query(str): the word or phrase to search
    topic_ids(list of ints): list of topic IDs where each id is in the range
    of range(self.n_topics).
    doc_ids(list of ints): list of document IDs where each id is an index into self.doctopics
    case_sensitive(bool): If True, case sensitive search

    Expand source code
    def search(self, query, topic_ids=[], doc_ids=[], case_sensitive=False):
        search documents for query string.
        Args:
            query(str):  the word or phrase to search
            topic_ids(list of ints): list of topic IDs where each id is in the range
                                     of range(self.n_topics).
            doc_ids (list of ints): list of document IDs where each id is an index
                                    into self.doctopics
            case_sensitive(bool):  If True, case sensitive search
        # setup pattern
        if not case_sensitive:
            query = query.lower()
        pattern = re.compile(r"\b%s\b" % query)
        # retrieve docs
        docs = self.get_docs(topic_ids=topic_ids, doc_ids=doc_ids)
        # search
        mb = master_bar(range(1))
        results = []
        for i in mb:
            for doc in progress_bar(docs, parent=mb):
                text = doc["text"]
                if not case_sensitive:
                    text = text.lower()
                matches = pattern.findall(text)
                if matches:
                    results.append(doc)
            if self.verbose:
                mb.write("done.")
        return results
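
    A sketch of keyword search restricted to particular topics (the query string and topic ID are placeholders):

        hits = tm.search("climate change", topic_ids=[5])
        print(len(hits), "matching documents")
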
    def train(self, texts, model_type='lda', n_topics=None, n_features=10000, min_df=5, max_df=0.5, stop_words='english', max_iter=5, lda_max_iter=None, lda_mode='online', token_pattern=None, hyperparam_kwargs=None)

    Fits a topic model to documents in texts.

    Example

    tm = ktrain.text.get_topic_model(docs, n_topics=20, n_features=1000, min_df=2, max_df=0.95)

    texts : list of str
    list of texts
    n_topics : int
    number of topics. If None, n_topics = min(400, sqrt(# documents / 2))
    n_features : int
    maximum words to consider
    max_df : float
    words in more than max_df proportion of docs discarded
    stop_words : str or list
    either 'english' for built-in stop words or a list of stop words to ignore
    max_iter : int
    maximum iterations for 'lda'. 5 is default if using lda_mode='online'.
    lda_max_iter : int
    alias for max_iter for backwards compatibility If lda_mode='batch', this should be increased (e.g., 1500). Ignored if model_type != 'lda'
    lda_mode : str
    one of {'online', 'batch'}. Ignored if model_type != 'lda'

    token_pattern : str
    regex pattern to use to tokenize documents. If None, a default tokenizer will be used
    hyperparam_kwargs : dict
    hyperparameters for LDA/NMF. Keys in this dict can be any of the following:
        alpha: alpha for LDA (default: 5./n_topics)
        beta: beta for LDA (default: 0.01)
        nmf_alpha_W: alpha_W for NMF (default: 0.0)
        nmf_alpha_H: alpha_H for NMF (default: 'same')
        l1_ratio: l1_ratio for NMF (default: 0)
        ngram_range: whether to consider bigrams, trigrams (default: (1,1))

    Returns

    tuple
    (model, vectorizer)
    Expand source code
    def train(
        self,
        texts,
        model_type="lda",
        n_topics=None,
        n_features=10000,
        min_df=5,
        max_df=0.5,
        stop_words="english",
        max_iter=5,
        lda_max_iter=None,
        lda_mode="online",
        token_pattern=None,
        hyperparam_kwargs=None,
    ):
        Fits a topic model to documents in <texts>.
        Example:
            tm = ktrain.text.get_topic_model(docs, n_topics=20,
                                            n_features=1000, min_df=2, max_df=0.95)
        Args:
            texts (list of str): list of texts
            n_topics (int): number of topics.
                            If None, n_topics = min(400, sqrt(# documents / 2))
            n_features (int):  maximum words to consider
            max_df (float): words in more than max_df proportion of docs discarded
            stop_words (str or list): either 'english' for built-in stop words or
                                     a list of stop words to ignore
            max_iter (int): maximum iterations for 'lda'.  5 is default if using lda_mode='online'.
            lda_max_iter (int): alias for max_iter for backwards compatibility
                                If lda_mode='batch', this should be increased (e.g., 1500).
                                Ignored if model_type != 'lda'
            lda_mode (str):  one of {'online', 'batch'}. Ignored if model_type != 'lda'
            token_pattern(str): regex pattern to use to tokenize documents.
                                If None, a default tokenizer will be used
            hyperparam_kwargs(dict): hyperparameters for LDA/NMF
                                     Keys in this dict can be any of the following:
                                         alpha: alpha for LDA  default: 5./n_topics
                                         beta: beta for LDA.  default:0.01
                                         nmf_alpha_W: alpha for NMF alpha_W (default is 0.0)
                                         nmf_alpha_H: alpha for NMF alpha_H (default is 'same')
                                         l1_ratio: l1_ratio for NMF. default: 0
                                         ngram_range:  whether to consider bigrams, trigrams. default: (1,1)
        Returns:
            tuple: (model, vectorizer)
        max_iter = lda_max_iter if lda_max_iter is not None else max_iter
        if hyperparam_kwargs is None:
            hyperparam_kwargs = {}
        alpha = hyperparam_kwargs.get("alpha", 5.0 / n_topics)
        nmf_alpha_W = hyperparam_kwargs.get("nmf_alpha_W", 0.0)
        nmf_alpha_H = hyperparam_kwargs.get("nmf_alpha_H", "same")
        beta = hyperparam_kwargs.get("beta", 0.01)
        l1_ratio = hyperparam_kwargs.get("l1_ratio", 0)
        ngram_range = hyperparam_kwargs.get("ngram_range", (1, 1))
        # adjust defaults based on language detected
        if texts is not None:
            lang = TU.detect_lang(texts)
            if lang != "en":
                stopwords = None if stop_words == "english" else stop_words
                token_pattern = (
                    r"(?u)\b\w+\b" if token_pattern is None else token_pattern
            if pp.is_nospace_lang(lang):
                text_list = []
                for t in texts:
                    text_list.append(" ".join(jieba.cut(t, HMM=False)))
                texts = text_list
            if self.verbose:
                print("lang: %s" % (lang))
        # preprocess texts
        if self.verbose:
            print("preprocessing texts...")
        if token_pattern is None:
            token_pattern = TU.DEFAULT_TOKEN_PATTERN
        # if token_pattern is None: token_pattern = r'(?u)\b\w\w+\b'
        vectorizer = CountVectorizer(
            max_df=max_df,
            min_df=min_df,
            max_features=n_features,
            stop_words=stop_words,
            token_pattern=token_pattern,
            ngram_range=ngram_range,
        )
        x_train = vectorizer.fit_transform(texts)
        # fit model
        if self.verbose:
            print("fitting model...")
        if model_type == "lda":
            model = LatentDirichletAllocation(
                n_components=n_topics,
                max_iter=max_iter,
                learning_method=lda_mode,
                learning_offset=50.0,
                doc_topic_prior=alpha,
                topic_word_prior=beta,
                verbose=self.verbose,
                random_state=0,
            )
        elif model_type == "nmf":
            model = NMF(
                n_components=n_topics,
                max_iter=max_iter,
                verbose=self.verbose,
                alpha_W=nmf_alpha_W,
                alpha_H=nmf_alpha_H,
                l1_ratio=l1_ratio,
                random_state=0,
            )
        else:
            raise ValueError("unknown model type:", str(model_type))
        model.fit(x_train)
        # save model and vectorizer and hyperparameter settings
        return (model, vectorizer)
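
    train is normally invoked indirectly through ktrain.text.get_topic_model, which accepts these arguments and forwards them here; a sketch with illustrative values (assuming that pass-through):

        # hypothetical: NMF topic model with bigrams and a custom alpha_W
        tm = ktrain.text.get_topic_model(
            texts,
            model_type="nmf",
            n_topics=30,
            n_features=5000,
            hyperparam_kwargs={"nmf_alpha_W": 0.1, "ngram_range": (1, 2)},
        )
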
    def train_recommender(self, n_neighbors=20, metric='minkowski', p=2)

    Trains a recommender that, given a single document, will return documents in the corpus that are semantically similar to it.

    n_neighbors (int): number of nearest neighbors used by the underlying sklearn NearestNeighbors index
    metric (str): distance metric passed to sklearn NearestNeighbors (default: 'minkowski')
    p (int): power parameter for the Minkowski metric (p=2 corresponds to Euclidean distance)

    Returns

    Expand source code
    def train_recommender(self, n_neighbors=20, metric="minkowski", p=2):
        Trains a recommender that, given a single document, will return
        documents in the corpus that are semantically similar to it.
        Args:
            n_neighbors (int):
        Returns:
        from sklearn.neighbors import NearestNeighbors
        rec = NearestNeighbors(n_neighbors=n_neighbors, metric=metric, p=p)
        probs = self.get_doctopics()
        rec.fit(probs)
        self.recommender = rec
        return
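
    A sketch (swapping in cosine distance purely as an illustration; metric and p are passed straight to sklearn's NearestNeighbors):

        tm.train_recommender(n_neighbors=20, metric="cosine")
        similar_docs = tm.recommend(text="an example query document", n=5)
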
    def train_scorer(self, topic_ids=[], doc_ids=[], n_neighbors=20)

    Trains a scorer that can score documents based on similarity to a seed set of documents represented by topic_ids and doc_ids.

    NOTE: The score method currently employs the use of LocalOutlierFactor, which means you should not try to score documents that were used in training. Only new, unseen documents should be scored for similarity. REFERENCE: https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor

    topic_ids(list of ints): list of topic IDs where each id is in the range
    of range(self.n_topics). Documents associated
    with these topic_ids will be used as seed set.
    doc_ids : list of ints
    list of document IDs where each id is an index into self.doctopics. Documents associated with these doc_ids will be used as seed set.

    Returns

    Expand source code
    def train_scorer(self, topic_ids=[], doc_ids=[], n_neighbors=20):
        Trains a scorer that can score documents based on similarity to a
        seed set of documents represented by topic_ids and doc_ids.
        NOTE: The score method currently employs the use of LocalOutlierFactor, which
        means you should not try to score documents that were used in training. Only
        new, unseen documents should be scored for similarity.
        REFERENCE:
        https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor
        Args:
            topic_ids(list of ints): list of topic IDs where each id is in the range
                                     of range(self.n_topics).  Documents associated
                                     with these topic_ids will be used as seed set.
            doc_ids (list of ints): list of document IDs where each id is an index
                                    into self.doctopics.  Documents associated
                                    with these doc_ids will be used as seed set.
        Returns:
        from sklearn.neighbors import LocalOutlierFactor
        clf = LocalOutlierFactor(
            n_neighbors=n_neighbors, novelty=True, contamination=0.1
        )
        probs = self.get_doctopics(topic_ids=topic_ids, doc_ids=doc_ids)
        clf.fit(probs)
        self.scorer = clf
        return
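
    A sketch using explicit document IDs as the seed set instead of topic IDs (the IDs are placeholders):

        tm.train_scorer(doc_ids=[0, 1, 2, 10, 42], n_neighbors=10)
        scores = tm.score(texts=other_unseen_texts)
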
    def visualize_documents(self, texts=None, doc_topics=None, width=700, height=700, point_size=5, title='Document Visualization', extra_info={}, colors=None, filepath=None)

    Generates a visualization of a set of documents based on the fitted topic model. If texts is supplied, raw documents will first be transformed into a document-topic matrix. If doc_topics is supplied, then it will be used for the visualization instead.

    texts(list of str): list of document texts. Mutually-exclusive with doc_topics.
    doc_topics(ndarray): pre-computed topic distribution for each document in texts. Mutually-exclusive with texts.
    width(int): width of image
    height(int): height of image
    point_size(int): size of circles in plot
    title(str): title of visualization
    extra_info(dict of lists): user-supplied information for each datapoint (attributes of the datapoint). The keys are field names. The values are lists, each of which must have the same number of elements as texts or doc_topics. These fields are displayed when hovering over datapoints in the visualization.
    colors(list of str): list of Hex color codes for each datapoint. Length of list must match either len(texts) or doc_topics.shape[0].
    filepath(str): optional filepath to save the interactive visualization

    Expand source code
    def visualize_documents(
        self,
        texts=None,
        doc_topics=None,
        width=700,
        height=700,
        point_size=5,
        title="Document Visualization",
        extra_info={},
        colors=None,
        filepath=None,
        Generates a visualization of a set of documents based on model.
        If <texts> is supplied, raw documents will be first transformed into document-topic
        matrix.  If <doc_topics> is supplied, then this will be used for visualization instead.
        Args:
            texts(list of str): list of document texts.  Mutually-exclusive with <doc_topics>
            doc_topics(ndarray): pre-computed topic distribution for each document in texts.
                                 Mutually-exclusive with <texts>.
            width(int): width of image
            height(int): height of image
            point_size(int): size of circles in plot
            title(str):  title of visualization
            extra_info(dict of lists): A user-supplied information for each datapoint (attributes of the datapoint).
                                       The keys are field names.  The values are lists - each of which must
                                       be the same number of elements as <texts> or <doc_topics>. These fields are displayed
                                       when hovering over datapoints in the visualization.
            colors(list of str):  list of Hex color codes for each datapoint.
                                  Length of list must match either len(texts) or doc_topics.shape[0]
            filepath(str):             Optional filepath to save the interactive visualization
        # error-checking
        if texts is not None:
            length = len(texts)
        else:
            length = doc_topics.shape[0]
        if colors is not None and len(colors) != length:
            raise ValueError(
                "length of colors is not consistent with length of texts or doctopics"
        if texts is not None and doc_topics is not None:
            raise ValueError("texts is mutually-exclusive with doc_topics")
        if texts is None and doc_topics is None:
            raise ValueError("One of texts or doc_topics is required.")
        if extra_info:
            invalid_keys = ["x", "y", "topic", "fill_color"]
            for k in extra_info.keys():
                if k in invalid_keys:
                    raise ValueError('cannot use "%s" as key in extra_info' % (k))
                lst = extra_info[k]
                if len(lst) != length:
                    raise ValueError("texts and extra_info lists must be same size")
        # check for bokeh
        try:
            import bokeh.plotting as bp
            from bokeh.io import output_notebook
            from bokeh.models import HoverTool
            from bokeh.plotting import save
        except:
            warnings.warn(
                "visualize_documents method requires bokeh package: pip install bokeh"
            return
        # prepare data
        if doc_topics is not None:
            X_topics = doc_topics
        else:
            if self.verbose:
                print("transforming texts...", end="")
            X_topics = self.predict(texts, harden=False)
            if self.verbose:
                print("done.")
        # reduce to 2-D
        if self.verbose:
            print("reducing to 2 dimensions...", end="")
        tsne_model = TSNE(
            n_components=2, verbose=self.verbose, random_state=0, angle=0.99, init="pca"
        tsne_lda = tsne_model.fit_transform(X_topics)
        print("done.")
        # get random colormap
        colormap = U.get_random_colors(self.n_topics)
        # generate inline visualization in Jupyter notebook
        lda_keys = self._harden_topics(X_topics)
        if colors is None:
            colors = colormap[lda_keys]
        topic_summaries = self.get_topics(n_words=5)
        os.environ["BOKEH_RESOURCES"] = "inline"
        output_notebook()
        dct = {
            "x": tsne_lda[:, 0],
            "y": tsne_lda[:, 1],
            "topic": [topic_summaries[tid] for tid in lda_keys],
            "fill_color": colors,
        tool_tups = [("index", "$index"), ("(x,y)", "($x,$y)"), ("topic", "@topic")]
        for k in extra_info.keys():
            dct[k] = extra_info[k]
            tool_tups.append((k, "@" + k))
        source = bp.ColumnDataSource(data=dct)
        hover = HoverTool(tooltips=tool_tups)
        p = bp.figure(
            width=width,
            height=height,
            tools=[hover, "save", "pan", "wheel_zoom", "box_zoom", "reset"],
            # tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave",
            title=title,
        # plot_lda = bp.figure(plot_width=1400, plot_height=1100,
        # title=title,
        # tools="pan,wheel_zoom,box_zoom,reset,hover,previewsave",
        # x_axis_type=None, y_axis_type=None, min_border=1)
        p.circle("x", "y", size=point_size, source=source, fill_color="fill_color")
        bp.show(p)
        if filepath is not None:
            bp.output_file(filepath)
            bp.save(p)
        return
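
    A sketch (requires the bokeh package; the extra_info key and variable names are placeholders):

        tm.visualize_documents(
            texts=texts,
            extra_info={"title": titles},        # titles: one string per document, shown on hover
            filepath="/tmp/doc_viz.html",
        )
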
    class Transformer (model_name, maxlen=128, class_names=[], classes=[], batch_size=None, use_with_learner=True)
    convenience class for text classification with Hugging Face transformers
    Usage:
       t = Transformer('distilbert-base-uncased', maxlen=128, classes=['neg', 'pos'], batch_size=16)
       train_dataset = t.preprocess_train(train_texts, train_labels)
       model = t.get_classifier()
       model.fit(train_dataset)
    
    Args:
        model_name (str):  name of Hugging Face pretrained model
        maxlen (int):  sequence length
        class_names(list):  list of strings of class names (e.g., 'positive', 'negative').
                            The index position of string is the class ID.
                            Not required for:
                              - regression problems
                              - binary/multi classification problems where
                                labels in y_train/y_test are in string format.
                                In this case, classes will be populated automatically.
                                get_classes() can be called to view discovered class labels.
                            The class_names argument replaces the old classes argument.
        classes(list):  alias for class_names.  Included for backwards-compatibility.
        use_with_learner(bool):  If False, preprocess_train and preprocess_test
                                 will return tf.Datasets for direct use with model.fit
                                 in tf.Keras.
                                 If True, preprocess_train and preprocess_test will
                                 return a ktrain TransformerDataset object for use with
                                 ktrain.get_learner.
        batch_size (int): batch_size - only required if use_with_learner=False
        convenience class for text classification with Hugging Face transformers
        Usage:
           t = Transformer('distilbert-base-uncased', maxlen=128, classes=['neg', 'pos'], batch_size=16)
           train_dataset = t.preprocess_train(train_texts, train_labels)
           model = t.get_classifier()
           model.fit(train_dataset)
        def __init__(
            self,
            model_name,
            maxlen=128,
            class_names=[],
            classes=[],
            batch_size=None,
            use_with_learner=True,
            Args:
                model_name (str):  name of Hugging Face pretrained model
                maxlen (int):  sequence length
                class_names(list):  list of strings of class names (e.g., 'positive', 'negative').
                                    The index position of string is the class ID.
                                    Not required for:
                                      - regression problems
                                      - binary/multi classification problems where
                                        labels in y_train/y_test are in string format.
                                        In this case, classes will be populated automatically.
                                        get_classes() can be called to view discovered class labels.
                                    The class_names argument replaces the old classes argument.
                classes(list):  alias for class_names.  Included for backwards-compatibility.
                use_with_learner(bool):  If False, preprocess_train and preprocess_test
                                         will return tf.Datasets for direct use with model.fit
                                         in tf.Keras.
                                         If True, preprocess_train and preprocess_test will
                                         return a ktrain TransformerDataset object for use with
                                         ktrain.get_learner.
                batch_size (int): batch_size - only required if use_with_learner=False
            multilabel = None  # force discovery of multilabel task from data in preprocess_train->set_multilabel
            class_names = self.migrate_classes(class_names, classes)
            if not use_with_learner and batch_size is None:
                raise ValueError("batch_size is required when use_with_learner=False")
            if multilabel and (class_names is None or not class_names):
                raise ValueError("classes argument is required when multilabel=True")
            super().__init__(
                model_name,
                maxlen,
                max_features=10000,
                class_names=class_names,
                multilabel=multilabel,
            self.batch_size = batch_size
            self.use_with_learner = use_with_learner
            self.lang = None
        def preprocess_train(self, texts, y=None, mode="train", verbose=1):
            Preprocess training set for a Transformer model
            Y values can be in one of the following forms:
            1) integers representing the class (index into array returned by get_classes)
               for binary and multiclass text classification.
               If labels are integers, class_names argument to Transformer constructor is required.
            2) strings representing the class (e.g., 'negative', 'positive').
               If labels are strings, class_names argument to Transformer constructor is ignored,
               as class labels will be extracted from y.
            3) multi-hot-encoded vector for multilabel text classification problems
               If labels are multi-hot-encoded, class_names argument to Transformer constructor is required.
            4) Numerical values for regression problems.
               <class_names> argument to Transformer constructor should NOT be supplied
            Args:
                texts (list of strings): text of documents
                y: labels
                mode (str):  If 'train' and prepare_for_learner=False,
                             a tf.Dataset will be returned with repeat enabled
                             for training with fit_generator
                verbose(bool): verbosity
            Returns:
              TransformerDataset if self.use_with_learner = True else tf.Dataset
            tseq = super().preprocess_train(texts, y=y, mode=mode, verbose=verbose)
            if self.use_with_learner:
                return tseq
            tseq.batch_size = self.batch_size
            train = mode == "train"
            return tseq.to_tfdataset(train=train)
        def preprocess_test(self, texts, y=None, verbose=1):
            Preprocess the validation or test set for a Transformer model
            Y values can be in one of the following forms:
            1) integers representing the class (index into array returned by get_classes)
               for binary and multiclass text classification.
               If labels are integers, class_names argument to Transformer constructor is required.
            2) strings representing the class (e.g., 'negative', 'positive').
               If labels are strings, class_names argument to Transformer constructor is ignored,
               as class labels will be extracted from y.
            3) multi-hot-encoded vector for multilabel text classification problems
               If labels are multi-hot-encoded, class_names argument to Transformer constructor is required.
            4) Numerical values for regression problems.
               <class_names> argument to Transformer constructor should NOT be supplied
            Args:
                texts (list of strings): text of documents
                y: labels
                verbose(bool): verbosity
            Returns:
                TransformerDataset if self.use_with_learner = True else tf.Dataset
            self.check_trained()
            return self.preprocess_train(texts, y=y, mode="test", verbose=verbose)

    Ancestors

  • TransformersPreprocessor
  • TextPreprocessor
  • Preprocessor
  • abc.ABC

    Methods

    def preprocess_test(self, texts, y=None, verbose=1)
    Preprocess the validation or test set for a Transformer model
    Y values can be in one of the following forms:
    1) integers representing the class (index into array returned by get_classes)
       for binary and multiclass text classification.
       If labels are integers, class_names argument to Transformer constructor is required.
    2) strings representing the class (e.g., 'negative', 'positive').
       If labels are strings, class_names argument to Transformer constructor is ignored,
       as class labels will be extracted from y.
    3) multi-hot-encoded vector for multilabel text classification problems
       If labels are multi-hot-encoded, class_names argument to Transformer constructor is required.
    4) Numerical values for regression problems.
       <class_names> argument to Transformer constructor should NOT be supplied
    Args:
        texts (list of strings): text of documents
        y: labels
        verbose(bool): verbosity
    Returns:
        TransformerDataset if self.use_with_learner = True else tf.Dataset
    Expand source code
    
    def preprocess_test(self, texts, y=None, verbose=1):
        Preprocess the validation or test set for a Transformer model
        Y values can be in one of the following forms:
        1) integers representing the class (index into array returned by get_classes)
           for binary and multiclass text classification.
           If labels are integers, class_names argument to Transformer constructor is required.
        2) strings representing the class (e.g., 'negative', 'positive').
           If labels are strings, class_names argument to Transformer constructor is ignored,
           as class labels will be extracted from y.
        3) multi-hot-encoded vector for multilabel text classification problems
           If labels are multi-hot-encoded, class_names argument to Transformer constructor is required.
        4) Numerical values for regression problems.
           <class_names> argument to Transformer constructor should NOT be supplied
        Args:
            texts (list of strings): text of documents
            y: labels
            verbose(bool): verbosity
        Returns:
            TransformerDataset if self.use_with_learner = True else tf.Dataset
        self.check_trained()
        return self.preprocess_train(texts, y=y, mode="test", verbose=verbose)
    def preprocess_train(self, texts, y=None, mode='train', verbose=1)
    Preprocess training set for a Transformer model
    Y values can be in one of the following forms:
    1) integers representing the class (index into array returned by get_classes)
       for binary and multiclass text classification.
       If labels are integers, class_names argument to Transformer constructor is required.
    2) strings representing the class (e.g., 'negative', 'positive').
       If labels are strings, class_names argument to Transformer constructor is ignored,
       as class labels will be extracted from y.
    3) multi-hot-encoded vector for multilabel text classification problems
       If labels are multi-hot-encoded, class_names argument to Transformer constructor is required.
    4) Numerical values for regression problems.
       <class_names> argument to Transformer constructor should NOT be supplied
    Args:
        texts (list of strings): text of documents
        y: labels
        mode (str):  If 'train' and prepare_for_learner=False,
                     a tf.Dataset will be returned with repeat enabled
                     for training with fit_generator
        verbose(bool): verbosity
    Returns:
      TransformerDataset if self.use_with_learner = True else tf.Dataset
    Expand source code
    
    def preprocess_train(self, texts, y=None, mode="train", verbose=1):
        Preprocess training set for a Transformer model
        Y values can be in one of the following forms:
        1) integers representing the class (index into array returned by get_classes)
           for binary and multiclass text classification.
           If labels are integers, class_names argument to Transformer constructor is required.
        2) strings representing the class (e.g., 'negative', 'positive').
           If labels are strings, class_names argument to Transformer constructor is ignored,
           as class labels will be extracted from y.
        3) multi-hot-encoded vector for multilabel text classification problems
           If labels are multi-hot-encoded, class_names argument to Transformer constructor is required.
        4) Numerical values for regression problems.
           <class_names> argument to Transformer constructor should NOT be supplied
        Args:
            texts (list of strings): text of documents
            y: labels
            mode (str):  If 'train' and prepare_for_learner=False,
                         a tf.Dataset will be returned with repeat enabled
                         for training with fit_generator
            verbose(bool): verbosity
        Returns:
          TransformerDataset if self.use_with_learner = True else tf.Dataset
        tseq = super().preprocess_train(texts, y=y, mode=mode, verbose=verbose)
        if self.use_with_learner:
            return tseq
        tseq.batch_size = self.batch_size
        train = mode == "train"
        return tseq.to_tfdataset(train=train)
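
    An end-to-end sketch tying the Transformer preprocessor to ktrain.get_learner (the dataset variables are placeholders; this mirrors the Usage block above):

        import ktrain
        from ktrain import text

        t = text.Transformer("distilbert-base-uncased", maxlen=128, class_names=["neg", "pos"])
        trn = t.preprocess_train(x_train, y_train)
        val = t.preprocess_test(x_test, y_test)
        model = t.get_classifier()
        learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=16)
        learner.fit_onecycle(5e-5, 1)
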

    Inherited members

  • TransformersPreprocessor:
  • get_classifier
  • get_regression_model
  • hf_convert_examples
  • load_model_and_configure_from_data
  • print_seqlen_stats
  • seqlen_stats

    class TransformerEmbedding (model_name, layers=[-2])

    Args:
        model_name (str):  name of Hugging Face pretrained model.
                           Choose from here: https://huggingface.co/transformers/pretrained_models.html
        layers(list): list of indexes indicating which hidden layers to use when
                      constructing the embedding (e.g., last=[-1])

    Expand source code
    
    class TransformerEmbedding:
        def __init__(self, model_name, layers=U.DEFAULT_TRANSFORMER_LAYERS):
            Args:
                model_name (str):  name of Hugging Face pretrained model.
                                   Choose from here: https://huggingface.co/transformers/pretrained_models.html
                layers(list): list of indexes indicating which hidden layers to use when
                              constructing the embedding (e.g., last=[-1])
            self.layers = layers
            self.model_name = model_name
            if model_name.startswith("xlm-roberta"):
                self.name = "xlm_roberta"
            else:
                self.name = model_name.split("-")[0]
            self.config = AutoConfig.from_pretrained(model_name)
            self.model_type = TFAutoModel
            self.tokenizer_type = AutoTokenizer
            if "bert-base-japanese" in model_name:
                self.tokenizer_type = transformers.BertJapaneseTokenizer
            self.tokenizer = self.tokenizer_type.from_pretrained(model_name)
            self.model = self._load_pretrained(model_name)
            try:
                self.embsize = self.embed("ktrain", word_level=False).shape[
                    1
                ]  # (batch_size, embsize)
            except:
                warnings.warn("could not determine Embedding size")
            # if type(self.model).__name__ not in [
            # "TFBertModel",
            # "TFDistilBertModel",
            # "TFAlbertModel",
            # "TFRobertaModel",
            # raise ValueError(
            # "TransformerEmbedding class currently only supports BERT-style models: "
            # + "Bert, DistilBert, RoBERTa and Albert and variants like BioBERT and SciBERT\n\n"
            # + "model received: %s (%s))" % (type(self.model).__name__, model_name)
        def _load_pretrained(self, model_name):
            load pretrained model
            if self.config is not None:
                self.config.output_hidden_states = True
                try:
                    model = self.model_type.from_pretrained(model_name, config=self.config)
                except:
                    warnings.warn(
                        "Could not load a Tensorflow version of model. (If this worked before, it might be an out-of-memory issue.) "
                        + "Attempting to download/load PyTorch version as TensorFlow model using from_pt=True. You will need PyTorch installed for this."
                    model = self.model_type.from_pretrained(
                        model_name, config=self.config, from_pt=True
            else:
                model = self.model_type.from_pretrained(
                    model_name, output_hidden_states=True
            return model
        def _reconstruct_word_ids(self, offsets):
            Reverse engineer the word_ids.
            word_ids = []
            last_word_id = -1
            last_offset = (-1, -1)
            for o in offsets:
                if o == (0, 0):
                    word_ids.append(None)
                    continue
                # must test to see if start is same as last offset start due to xlm-roberta quirk with tokens like 070
                if o[0] == last_offset[0] or o[0] == last_offset[1]:
                    word_ids.append(last_word_id)
                elif o[0] > last_offset[1]:
                    last_word_id += 1
                    word_ids.append(last_word_id)
                last_offset = o
            return word_ids
        def embed(
            self,
            texts,
            word_level=True,
            max_length=512,
            aggregation_strategy="first",
            layers=U.DEFAULT_TRANSFORMER_LAYERS,
            Get embedding for word, phrase, or sentence.
            Args:
              text(str|list): word, phrase, or sentence or list of them representing a batch
              word_level(bool): If True, returns embedding for each token in supplied texts.
                                If False, returns embedding for each text in texts
              max_length(int): max length of tokens
              aggregation_strategy(str): If 'first', vector of first subword is used as representation.
                                         If 'average', mean of all subword vectors is used.
              layers(list): list of indexes indicating which hidden layers to use when
                            constructing the embedding (e.g., last hidden state is [-1])
            Returns:
                np.ndarray : embeddings
            if isinstance(texts, str):
                texts = [texts]
            if not isinstance(texts[0], str):
                texts = [" ".join(text) for text in texts]
            sentences = []
            for text in texts:
                sentences.append(self.tokenizer.tokenize(text))
            maxlen = (
                len(
                    max(
                        [tokens for tokens in sentences],
                        key=len,
                    )
                )
                + 2
            )
            if max_length is not None and maxlen > max_length:
                maxlen = max_length  # added due to issue #270
            sentences = []
            all_input_ids = []
            all_input_masks = []
            all_word_ids = []
            all_offsets = []  # retained but not currently used as of v0.36.1 (#492)
            for text in texts:
                encoded = self.tokenizer.encode_plus(
                    text, max_length=maxlen, truncation=True, return_offsets_mapping=True
                input_ids = encoded["input_ids"]
                offsets = encoded["offset_mapping"]
                del encoded["offset_mapping"]
                inp = encoded["input_ids"][:]
                inp = inp[1:] if inp[0] == self.tokenizer.cls_token_id else inp
                inp = inp[:-1] if inp[-1] == self.tokenizer.sep_token_id else inp
                tokens = self.tokenizer.convert_ids_to_tokens(inp)
                if len(tokens) > maxlen - 2:
                    tokens = tokens[0 : (maxlen - 2)]
                sentences.append(tokens)
                input_mask = [1] * len(input_ids)
                while len(input_ids) < maxlen:
                    input_ids.append(0)
                    input_mask.append(0)
                all_input_ids.append(input_ids)
                all_input_masks.append(input_mask)
                # Note about Issue #492:
                # deberta includes preceding space in offset_mapping (https://www.kaggle.com/code/junkoda/be-aware-of-white-space-deberta-roberta)
                # models like bert-base-cased produce word_ids that do not correspond to whitespace tokenization (e.g., "score 99.9%", "BRUSSELS 1996-08-22")
                # Therefore, we use offset_mappings unless the model is deberta for now.
                word_ids = (
                    encoded.word_ids()
                    if "deberta" in self.model_name
                    else self._reconstruct_word_ids(offsets)
                all_word_ids.append(word_ids)
                all_offsets.append(offsets)
            all_input_ids = np.array(all_input_ids)
            all_input_masks = np.array(all_input_masks)
            outputs = self.model(all_input_ids, attention_mask=all_input_masks)
            hidden_states = outputs[-1]  # output_hidden_states=True
            # compile raw embeddings
            if len(self.layers) == 1:
                # raw_embeddings = hidden_states[-1].numpy()
                raw_embeddings = hidden_states[self.layers[0]].numpy()
            else:
                raw_embeddings = []
                for batch_id in range(hidden_states[0].shape[0]):
                    token_embeddings = []
                    for token_id in range(hidden_states[0].shape[1]):
                        all_layers = []
                        for layer_id in self.layers:
                            all_layers.append(
                                hidden_states[layer_id][batch_id][token_id].numpy()
                        token_embeddings.append(np.concatenate(all_layers))
                    raw_embeddings.append(token_embeddings)
                raw_embeddings = np.array(raw_embeddings)
            if not word_level:  # sentence-level embedding
                return np.mean(raw_embeddings, axis=1)
            # all space-separated tokens in input should be assigned a single embedding vector
            # example: If 99.9% is a token, then it gets a single embedding.
            # example: If input is pre-tokenized (i.e., 99 . 9 %), then there are four embedding vectors
            filtered_embeddings = []
            for i in range(len(raw_embeddings)):
                filtered_embedding = []
                raw_embedding = raw_embeddings[i]
                subvectors = []
                last_word_id = -1
                for j in range(len(all_offsets[i])):
                    word_id = all_word_ids[i][j]
                    if word_id is None:
                        continue
                    if word_id == last_word_id:
                        subvectors.append(raw_embedding[j])
                    if word_id > last_word_id:
                        if len(subvectors) > 0:
                            if aggregation_strategy == "average":
                                filtered_embedding.append(np.mean(subvectors, axis=0))
                            else:
                                filtered_embedding.append(subvectors[0])
                            subvectors = []
                        subvectors.append(raw_embedding[j])
                        last_word_id = word_id
                if len(subvectors) > 0:
                    if aggregation_strategy == "average":
                        filtered_embedding.append(np.mean(subvectors, axis=0))
                    else:
                        filtered_embedding.append(subvectors[0])
                    subvectors = []
                filtered_embeddings.append(filtered_embedding)
            # pad embeddings with zeros
            max_length = max([len(e) for e in filtered_embeddings])
            embeddings = []
            for e in filtered_embeddings:
                for i in range(max_length - len(e)):
                    e.append(np.zeros((self.embsize,)))
                embeddings.append(np.array(e))
            return np.array(embeddings)

    Methods

    def embed(self, texts, word_level=True, max_length=512, aggregation_strategy='first', layers=[-2])
    Get embedding for word, phrase, or sentence.
    Args:
      text(str|list): word, phrase, or sentence or list of them representing a batch
      word_level(bool): If True, returns embedding for each token in supplied texts.
                        If False, returns embedding for each text in texts
      max_length(int): max length of tokens
      aggregation_strategy(str): If 'first', vector of first subword is used as representation.
                                 If 'average', mean of all subword vectors is used.
      layers(list): list of indexes indicating which hidden layers to use when
                    constructing the embedding (e.g., last hidden state is [-1])
    Returns:
        np.ndarray : embeddings
    Expand source code
    
    def embed(
        self,
        texts,
        word_level=True,
        max_length=512,
        aggregation_strategy="first",
        layers=U.DEFAULT_TRANSFORMER_LAYERS,
        Get embedding for word, phrase, or sentence.
        Args:
          text(str|list): word, phrase, or sentence or list of them representing a batch
          word_level(bool): If True, returns embedding for each token in supplied texts.
                            If False, returns embedding for each text in texts
          max_length(int): max length of tokens
          aggregation_strategy(str): If 'first', vector of first subword is used as representation.
                                     If 'average', mean of all subword vectors is used.
          layers(list): list of indexes indicating which hidden layers to use when
                        constructing the embedding (e.g., last hidden state is [-1])
        Returns:
            np.ndarray : embeddings
        if isinstance(texts, str):
            texts = [texts]
        if not isinstance(texts[0], str):
            texts = [" ".join(text) for text in texts]
        sentences = []
        for text in texts:
            sentences.append(self.tokenizer.tokenize(text))
        maxlen = (
            len(
                max(
                    [tokens for tokens in sentences],
                    key=len,
                )
            )
            + 2
        )
        if max_length is not None and maxlen > max_length:
            maxlen = max_length  # added due to issue #270
        sentences = []
        all_input_ids = []
        all_input_masks = []
        all_word_ids = []
        all_offsets = []  # retained but not currently used as of v0.36.1 (#492)
        for text in texts:
            encoded = self.tokenizer.encode_plus(
                text, max_length=maxlen, truncation=True, return_offsets_mapping=True
            input_ids = encoded["input_ids"]
            offsets = encoded["offset_mapping"]
            del encoded["offset_mapping"]
            inp = encoded["input_ids"][:]
            inp = inp[1:] if inp[0] == self.tokenizer.cls_token_id else inp
            inp = inp[:-1] if inp[-1] == self.tokenizer.sep_token_id else inp
            tokens = self.tokenizer.convert_ids_to_tokens(inp)
            if len(tokens) > maxlen - 2:
                tokens = tokens[0 : (maxlen - 2)]
            sentences.append(tokens)
            input_mask = [1] * len(input_ids)
            while len(input_ids) < maxlen:
                input_ids.append(0)
                input_mask.append(0)
            all_input_ids.append(input_ids)
            all_input_masks.append(input_mask)
            # Note about Issue #492:
            # deberta includes preceding space in offset_mapping (https://www.kaggle.com/code/junkoda/be-aware-of-white-space-deberta-roberta)
            # models like bert-base-cased produce word_ids that do not correspond to whitespace tokenization (e.g., "score 99.9%", "BRUSSELS 1996-08-22")
            # Therefore, we use offset_mappings unless the model is deberta for now.
            word_ids = (
                encoded.word_ids()
                if "deberta" in self.model_name
                else self._reconstruct_word_ids(offsets)
            all_word_ids.append(word_ids)
            all_offsets.append(offsets)
        all_input_ids = np.array(all_input_ids)
        all_input_masks = np.array(all_input_masks)
        outputs = self.model(all_input_ids, attention_mask=all_input_masks)
        hidden_states = outputs[-1]  # output_hidden_states=True
        # compile raw embeddings
        if len(self.layers) == 1:
            # raw_embeddings = hidden_states[-1].numpy()
            raw_embeddings = hidden_states[self.layers[0]].numpy()
        else:
            raw_embeddings = []
            for batch_id in range(hidden_states[0].shape[0]):
                token_embeddings = []
                for token_id in range(hidden_states[0].shape[1]):
                    all_layers = []
                    for layer_id in self.layers:
                        all_layers.append(
                            hidden_states[layer_id][batch_id][token_id].numpy()
                    token_embeddings.append(np.concatenate(all_layers))
                raw_embeddings.append(token_embeddings)
            raw_embeddings = np.array(raw_embeddings)
        if not word_level:  # sentence-level embedding
            return np.mean(raw_embeddings, axis=1)
        # all space-separated tokens in input should be assigned a single embedding vector
        # example: If 99.9% is a token, then it gets a single embedding.
        # example: If input is pre-tokenized (i.e., 99 . 9 %), then there are four embedding vectors
        filtered_embeddings = []
        for i in range(len(raw_embeddings)):
            filtered_embedding = []
            raw_embedding = raw_embeddings[i]
            subvectors = []
            last_word_id = -1
            for j in range(len(all_offsets[i])):
                word_id = all_word_ids[i][j]
                if word_id is None:
                    continue
                if word_id == last_word_id:
                    subvectors.append(raw_embedding[j])
                if word_id > last_word_id:
                    if len(subvectors) > 0:
                        if aggregation_strategy == "average":
                            filtered_embedding.append(np.mean(subvectors, axis=0))
                        else:
                            filtered_embedding.append(subvectors[0])
                        subvectors = []
                    subvectors.append(raw_embedding[j])
                    last_word_id = word_id
            if len(subvectors) > 0:
                if aggregation_strategy == "average":
                    filtered_embedding.append(np.mean(subvectors, axis=0))
                else:
                    filtered_embedding.append(subvectors[0])
                subvectors = []
            filtered_embeddings.append(filtered_embedding)
        # pad embeddings with zeros
        max_length = max([len(e) for e in filtered_embeddings])
        embeddings = []
        for e in filtered_embeddings:
            for i in range(max_length - len(e)):
                e.append(np.zeros((self.embsize,)))
            embeddings.append(np.array(e))
        return np.array(embeddings)
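
    A sketch of sentence-level vs. word-level embeddings (the model name is illustrative):

        from ktrain.text import TransformerEmbedding

        te = TransformerEmbedding("bert-base-uncased", layers=[-2])
        sent_vec = te.embed("ktrain makes deep learning easier", word_level=False)    # shape: (1, embsize)
        word_vecs = te.embed(["ktrain makes deep learning easier"], word_level=True)  # shape: (1, n_words, embsize)
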

    class TransformerSummarizer (model_name='facebook/bart-large-cnn', device=None)

    interface to Transformer-based text summarization

    interface to BART-based text summarization using transformers library
    Args:
      model_name(str): name of BART model for summarization
      device(str): device to use (e.g., 'cuda', 'cpu')
    Expand source code
    
    class TransformerSummarizer(TorchBase):
        interface to Transformer-based text summarization
        def __init__(self, model_name="facebook/bart-large-cnn", device=None):
            interface to BART-based text summarization using transformers library
            Args:
              model_name(str): name of BART model for summarization
              device(str): device to use (e.g., 'cuda', 'cpu')
            if "bart" not in model_name:
                raise ValueError("TransformerSummarizer currently only accepts BART models")
            super().__init__(device=device)
            from transformers import BartForConditionalGeneration, BartTokenizer
            self.tokenizer = BartTokenizer.from_pretrained(model_name)
            self.model = BartForConditionalGeneration.from_pretrained(model_name).to(
                self.torch_device
        def summarize(
            self,
            doc,
            max_length=150,
            min_length=56,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
            num_beams=4,
            **kwargs,
            Summarize document text.  Extra arguments are fed to generate method
            Args:
              doc(str): text of document
            Returns:
              str: summary text
            import torch
            with torch.no_grad():
                answers_input_ids = self.tokenizer.batch_encode_plus(
                    [doc], return_tensors="pt", truncation=True, max_length=1024
                )["input_ids"].to(self.torch_device)
                summary_ids = self.model.generate(
                    answers_input_ids,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                    max_length=max_length,
                    min_length=min_length,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    **kwargs,
                exec_sum = self.tokenizer.decode(
                    summary_ids.squeeze(), skip_special_tokens=True
            return exec_sum

    Ancestors

  • TorchBase

    Methods

    def summarize(self, doc, max_length=150, min_length=56, no_repeat_ngram_size=3, length_penalty=2.0, num_beams=4, **kwargs)
    Summarize document text.  Extra arguments are fed to generate method
    Args:
      doc(str): text of document
    Returns:
      str: summary text
    Expand source code
    
    def summarize(
        self,
        doc,
        max_length=150,
        min_length=56,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        num_beams=4,
        **kwargs,
        Summarize document text.  Extra arguments are fed to generate method
        Args:
          doc(str): text of document
        Returns:
          str: summary text
        import torch
        with torch.no_grad():
            answers_input_ids = self.tokenizer.batch_encode_plus(
                [doc], return_tensors="pt", truncation=True, max_length=1024
            )["input_ids"].to(self.torch_device)
            summary_ids = self.model.generate(
                answers_input_ids,
                num_beams=num_beams,
                length_penalty=length_penalty,
                max_length=max_length,
                min_length=min_length,
                no_repeat_ngram_size=no_repeat_ngram_size,
                **kwargs,
            )
            exec_sum = self.tokenizer.decode(
                summary_ids.squeeze(), skip_special_tokens=True
            )
        return exec_sum

    Inherited members

  • TorchBase:
  • quantize_model
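
    A minimal usage sketch for TransformerSummarizer (illustrative only: the placeholder document text and generation settings below are assumptions; the default facebook/bart-large-cnn weights are downloaded from the Hugging Face hub on first use):

    from ktrain.text.summarization import TransformerSummarizer

    ts = TransformerSummarizer()  # defaults to facebook/bart-large-cnn
    doc = "Replace this placeholder with the long article text you want to summarize ..."
    # extra keyword arguments are passed through to the underlying generate method
    print(ts.summarize(doc, max_length=120, min_length=40, num_beams=4))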
  • class Translator (model_name=None, device=None, quantize=False)

    Translator: basic wrapper around MarianMT model for language translation

    basic wrapper around MarianMT model for language translation
    Args:
      model_name(str): Helsinki-NLP model
      device(str): device to use (e.g., 'cuda', 'cpu')
      quantize(bool): If True, use quantization.
    Expand source code
    
    class Translator(TorchBase):
        """
        Translator: basic wrapper around MarianMT model for language translation
        """

        def __init__(self, model_name=None, device=None, quantize=False):
            """
            basic wrapper around MarianMT model for language translation
            Args:
              model_name(str): Helsinki-NLP model
              device(str): device to use (e.g., 'cuda', 'cpu')
              quantize(bool): If True, use quantization.
            """
            if "Helsinki-NLP" not in model_name:
                warnings.warn(
                    "Translator requires a Helsinki-NLP model: https://huggingface.co/Helsinki-NLP"
                )
            super().__init__(device=device, quantize=quantize)
            from transformers import MarianMTModel, MarianTokenizer
            self.tokenizer = MarianTokenizer.from_pretrained(model_name)
            self.model = MarianMTModel.from_pretrained(model_name).to(self.torch_device)
            if quantize:
                self.model = self.quantize_model(self.model)
        def translate(self, src_text, join_with="\n", num_beams=1, early_stopping=False):
            """
            Translate document (src_text).
            To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
            Args:
              src_text(str): source text.
                             The source text can either be a single sentence or an entire document with multiple sentences
                             and paragraphs.
                             IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                             If the input text is very large (e.g., an entire book), you should
                                             break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                             feed each chunk separately into translate to avoid out-of-memory issues.
              join_with(str):  list of translated sentences will be delimited with this character.
                               default: each sentence on separate line
              num_beams(int): Number of beams for beam search. Defaults to 1, which means no beam search.
              early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                                     are finished per batch or not. Defaults to False.
            Returns:
              str: translated text
            """
            sentences = TU.sent_tokenize(src_text)
            tgt_sentences = self.translate_sentences(
                sentences, num_beams=num_beams, early_stopping=early_stopping
            )
            return join_with.join(tgt_sentences)
        def translate_sentences(self, sentences, num_beams=1, early_stopping=False):
            """
            Translate sentences using model_name as model.
            To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
            Args:
              sentences(list): list of strings representing sentences that need to be translated
                             IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                             If the input text is very large (e.g., an entire book), you should
                                             break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                             feed each chunk separately into translate to avoid out-of-memory issues.
              num_beams(int): Number of beams for beam search. Defaults to 1, which means no beam search.
              early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                                     are finished per batch or not. Defaults to False.
            Returns:
              list: translated sentences
            """
            import torch
            with torch.no_grad():
                translated = self.model.generate(
                    **self.tokenizer.prepare_seq2seq_batch(
                        sentences, return_tensors="pt"
                    ).to(self.torch_device),
                    num_beams=num_beams,
                    early_stopping=early_stopping,
                )
                tgt_sentences = [
                    self.tokenizer.decode(t, skip_special_tokens=True) for t in translated
                ]
            return tgt_sentences

    Ancestors

  • TorchBase

    Methods

    def translate(self, src_text, join_with='\n', num_beams=1, early_stopping=False)
    Translate document (src_text).
    To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
    Args:
      src_text(str): source text.
                     The source text can either be a single sentence or an entire document with multiple sentences
                     and paragraphs.
                     IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                     If the input text is very large (e.g., an entire book), you should
                                      break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                     feed each chunk separately into translate to avoid out-of-memory issues.
      join_with(str):  list of translated sentences will be delimited with this character.
                       default: each sentence on separate line
      num_beams(int): Number of beams for beam search. Defaults to 1, which means no beam search.
      early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                             are finished per batch or not. Defaults to False.
    Returns:
      str: translated text
    Expand source code
    
    def translate(self, src_text, join_with="\n", num_beams=1, early_stopping=False):
        """
        Translate document (src_text).
        To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
        Args:
          src_text(str): source text.
                         The source text can either be a single sentence or an entire document with multiple sentences
                         and paragraphs.
                         IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                         If the input text is very large (e.g., an entire book), you should
                                         break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                         feed each chunk separately into translate to avoid out-of-memory issues.
          join_with(str):  list of translated sentences will be delimited with this character.
                           default: each sentence on separate line
          num_beams(int): Number of beams for beam search. Defaults to 1, which means no beam search.
          early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                                 are finished per batch or not. Defaults to False.
        Returns:
          str: translated text
        """
        sentences = TU.sent_tokenize(src_text)
        tgt_sentences = self.translate_sentences(
            sentences, num_beams=num_beams, early_stopping=early_stopping
        )
        return join_with.join(tgt_sentences)
    def translate_sentences(self, sentences, num_beams=1, early_stopping=False)
    Translate sentences using model_name as model.
    To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
    Args:
      sentences(list): list of strings representing sentences that need to be translated
                     IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                     If the input text is very large (e.g., an entire book), you should
                                      break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                     feed each chunk separately into translate to avoid out-of-memory issues.
      num_beams(int): Number of beams for beam search. Defaults to 1, which means no beam search.
      early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                             are finished per batch or not. Defaults to False.
    Returns:
      list: translated sentences
    Expand source code
    
    def translate_sentences(self, sentences, num_beams=1, early_stopping=False):
        """
        Translate sentences using model_name as model.
        To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
        Args:
          sentences(list): list of strings representing sentences that need to be translated
                         IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
                                         If the input text is very large (e.g., an entire book), you should
                                         break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
                                         feed each chunk separately into translate to avoid out-of-memory issues.
          num_beams(int): Number of beams for beam search. Defaults to 1, which means no beam search.
          early_stopping(bool):  Whether to stop the beam search when at least ``num_beams`` sentences
                                 are finished per batch or not. Defaults to False.
        Returns:
          list: translated sentences
        """
        import torch
        with torch.no_grad():
            translated = self.model.generate(
                **self.tokenizer.prepare_seq2seq_batch(
                    sentences, return_tensors="pt"
                ).to(self.torch_device),
                num_beams=num_beams,
                early_stopping=early_stopping,
            )
            tgt_sentences = [
                self.tokenizer.decode(t, skip_special_tokens=True) for t in translated
            ]
        return tgt_sentences

    Inherited members

  • TorchBase:
  • quantize_model
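
    A minimal usage sketch for Translator; the German-to-English model name below is one illustrative choice from the Helsinki-NLP collection on the Hugging Face hub, and the input sentence is a placeholder:

    from ktrain.text.translation import Translator

    translator = Translator(model_name="Helsinki-NLP/opus-mt-de-en")
    # num_beams/early_stopping are optional; they trade speed for quality as described above
    print(translator.translate("Das ist ein Beispielsatz.", num_beams=4, early_stopping=True))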
  • class ZeroShotClassifier (model_name='facebook/bart-large-mnli', device=None, quantize=False)

    interface to Zero Shot Topic Classifier

    ZeroShotClassifier constructor
    Args:
      model_name(str): name of a BART NLI model
      device(str): device to use (e.g., 'cuda', 'cpu')
      quantize(bool): If True, the model will be quantized for faster inference
    Expand source code
    
    class ZeroShotClassifier(TorchBase):
        """
        interface to Zero Shot Topic Classifier
        """

        def __init__(
            self, model_name="facebook/bart-large-mnli", device=None, quantize=False
        ):
            """
            ZeroShotClassifier constructor
            Args:
              model_name(str): name of a BART NLI model
              device(str): device to use (e.g., 'cuda', 'cpu')
              quantize(bool): If True, the model will be quantized for faster inference
            """
            if "mnli" not in model_name and "xnli" not in model_name:
                raise ValueError("ZeroShotClassifier requires an MNLI or XNLI model")
            super().__init__(device=device, quantize=quantize)
            from transformers import AutoModelForSequenceClassification, AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
                self.torch_device
            )
            if quantize:
                self.model = self.quantize_model(self.model)
        def predict(
            self,
            docs,
            labels=[],
            include_labels=False,
            multilabel=True,
            max_length=512,
            batch_size=8,
            nli_template="This text is about {}.",
            topic_strings=[],
        ):
            """
            This method performs zero-shot text classification using Natural Language Inference (NLI).
            Args:
              docs(list|str): text of document or list of texts
              labels(list): a list of strings representing topics of your choice
                            Example:
                              labels=['political science', 'sports', 'science']
              include_labels(bool): If True, will return topic labels along with topic probabilities
              multilabel(bool): If True, labels are considered independent and multiple labels can be predicted true for a document (each score can be close to 1).
                                If False, scores are normalized such that probabilities sum to 1.
              max_length(int): truncate long documents to this many tokens
              batch_size(int): batch_size to use. default:8
                               Increase this value to speed up predictions - especially
                               if len(topic_strings) is large.
              nli_template(str): labels are inserted into this template for use as hypotheses in natural language inference
              topic_strings(list): alias for labels parameter for backwards compatibility
            Returns:
              inferred probabilities or list of inferred probabilities if docs is a list
            """
            # error checks
            is_str_input = False
            if not isinstance(docs, (list, np.ndarray)):
                docs = [docs]
                is_str_input = True
            if not isinstance(docs[0], str):
                raise ValueError(
                    "docs must be string or a list of strings representing document(s)"
                )
            if len(labels) > 0 and len(topic_strings) > 0:
                raise ValueError("labels and topic_strings are mutually exclusive")
            if not labels and not topic_strings:
                raise ValueError("labels must be a list of strings")
            if topic_strings:
                labels = topic_strings
            # convert to sequences
            sequence_pairs = []
            for premise in docs:
                sequence_pairs.extend(
                    [[premise, nli_template.format(label)] for label in labels]
                )
            if batch_size > len(sequence_pairs):
                batch_size = len(sequence_pairs)
            if len(sequence_pairs) >= 100 and batch_size == 8:
                warnings.warn(
                    "TIP: Try increasing batch_size to speedup ZeroShotClassifier predictions"
                )
            num_chunks = math.ceil(len(sequence_pairs) / batch_size)
            sequence_chunks = list2chunks(sequence_pairs, n=num_chunks)
            # inference
            import torch
            with torch.no_grad():
                outputs = []
                for sequences in sequence_chunks:
                    batch = self.tokenizer.batch_encode_plus(
                        sequences,
                        return_tensors="pt",
                        max_length=max_length,
                        truncation="only_first",
                        padding=True,
                    ).to(self.torch_device)
                    logits = self.model(
                        batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        return_dict=False,
                    )[0]
                    outputs.extend(logits.cpu().detach().numpy())
                    # entail_contradiction_logits = logits[:,[0,2]]
                    # probs = entail_contradiction_logits.softmax(dim=1)
                    # true_probs = list(probs[:,1].cpu().detach().numpy())
                    # result.extend(true_probs)
            outputs = np.array(outputs)
            outputs = outputs.reshape((len(docs), len(labels), -1))
            # process outputs
            # 2020-08-24: modified based on transformers pipeline implementation
            if multilabel:
                # softmax over the entailment vs. contradiction dim for each label independently
                entail_contr_logits = outputs[..., [0, -1]]
                scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(
                    -1, keepdims=True
                )
                scores = scores[..., 1]
            else:
                # softmax the "entailment" logits over all candidate labels
                entail_logits = outputs[..., -1]
                scores = np.exp(entail_logits) / np.exp(entail_logits).sum(
                    -1, keepdims=True
                )
            scores = scores.tolist()
            if include_labels:
                scores = [list(zip(labels, s)) for s in scores]
            if is_str_input:
                scores = scores[0]
            return scores

    Ancestors

  • TorchBase

    Methods

    def predict(self, docs, labels=[], include_labels=False, multilabel=True, max_length=512, batch_size=8, nli_template='This text is about {}.', topic_strings=[])
    This method performs zero-shot text classification using Natural Language Inference (NLI).
    Args:
      docs(list|str): text of document or list of texts
      labels(list): a list of strings representing topics of your choice
                    Example:
                      labels=['political science', 'sports', 'science']
      include_labels(bool): If True, will return topic labels along with topic probabilities
      multilabel(bool): If True, labels are considered independent and multiple labels can be predicted true for a document (each score can be close to 1).
                        If False, scores are normalized such that probabilities sum to 1.
      max_length(int): truncate long documents to this many tokens
      batch_size(int): batch_size to use. default:8
                       Increase this value to speed up predictions - especially
                       if len(topic_strings) is large.
      nli_template(str): labels are inserted into this template for use as hypotheses in natural language inference
      topic_strings(list): alias for labels parameter for backwards compatibility
    Returns:
      inferred probabilities or list of inferred probabilities if docs is a list
    Expand source code
    
    def predict(
        self,
        docs,
        labels=[],
        include_labels=False,
        multilabel=True,
        max_length=512,
        batch_size=8,
        nli_template="This text is about {}.",
        topic_strings=[],
    ):
        """
        This method performs zero-shot text classification using Natural Language Inference (NLI).
        Args:
          docs(list|str): text of document or list of texts
          labels(list): a list of strings representing topics of your choice
                        Example:
                          labels=['political science', 'sports', 'science']
          include_labels(bool): If True, will return topic labels along with topic probabilities
          multilabel(bool): If True, labels are considered independent and multiple labels can be predicted true for a document (each score can be close to 1).
                            If False, scores are normalized such that probabilities sum to 1.
          max_length(int): truncate long documents to this many tokens
          batch_size(int): batch_size to use. default:8
                           Increase this value to speed up predictions - especially
                           if len(topic_strings) is large.
          nli_template(str): labels are inserted into this template for use as hypotheses in natural language inference
          topic_strings(list): alias for labels parameter for backwards compatibility
        Returns:
          inferred probabilities or list of inferred probabilities if docs is a list
        """
        # error checks
        is_str_input = False
        if not isinstance(docs, (list, np.ndarray)):
            docs = [docs]
            is_str_input = True
        if not isinstance(docs[0], str):
            raise ValueError(
                "docs must be string or a list of strings representing document(s)"
            )
        if len(labels) > 0 and len(topic_strings) > 0:
            raise ValueError("labels and topic_strings are mutually exclusive")
        if not labels and not topic_strings:
            raise ValueError("labels must be a list of strings")
        if topic_strings:
            labels = topic_strings
        # convert to sequences
        sequence_pairs = []
        for premise in docs:
            sequence_pairs.extend(
                [[premise, nli_template.format(label)] for label in labels]
            )
        if batch_size > len(sequence_pairs):
            batch_size = len(sequence_pairs)
        if len(sequence_pairs) >= 100 and batch_size == 8:
            warnings.warn(
                "TIP: Try increasing batch_size to speedup ZeroShotClassifier predictions"
            )
        num_chunks = math.ceil(len(sequence_pairs) / batch_size)
        sequence_chunks = list2chunks(sequence_pairs, n=num_chunks)
        # inference
        import torch
        with torch.no_grad():
            outputs = []
            for sequences in sequence_chunks:
                batch = self.tokenizer.batch_encode_plus(
                    sequences,
                    return_tensors="pt",
                    max_length=max_length,
                    truncation="only_first",
                    padding=True,
                ).to(self.torch_device)
                logits = self.model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    return_dict=False,
                )[0]
                outputs.extend(logits.cpu().detach().numpy())
                # entail_contradiction_logits = logits[:,[0,2]]
                # probs = entail_contradiction_logits.softmax(dim=1)
                # true_probs = list(probs[:,1].cpu().detach().numpy())
                # result.extend(true_probs)
        outputs = np.array(outputs)
        outputs = outputs.reshape((len(docs), len(labels), -1))
        # process outputs
        # 2020-08-24: modified based on transformers pipeline implementation
        if multilabel:
            # softmax over the entailment vs. contradiction dim for each label independently
            entail_contr_logits = outputs[..., [0, -1]]
            scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(
                -1, keepdims=True
            )
            scores = scores[..., 1]
        else:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = outputs[..., -1]
            scores = np.exp(entail_logits) / np.exp(entail_logits).sum(
                -1, keepdims=True
            )
        scores = scores.tolist()
        if include_labels:
            scores = [list(zip(labels, s)) for s in scores]
        if is_str_input:
            scores = scores[0]
        return scores
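
    To make the multilabel branch above concrete, the following toy NumPy sketch (with made-up NLI logits for one document and three labels) contrasts the two normalizations: a per-label softmax over the entailment/contradiction pair versus a single softmax of the entailment logits across all labels:

    import numpy as np

    # made-up logits: 1 document x 3 labels x (contradiction, neutral, entailment)
    outputs = np.array([[[2.0, 0.1, 1.0],
                         [0.5, 0.2, 3.0],
                         [1.5, 0.3, 0.5]]])

    # multilabel=True: each label scored independently (entailment vs. contradiction)
    entail_contr = outputs[..., [0, -1]]
    multilabel_scores = np.exp(entail_contr) / np.exp(entail_contr).sum(-1, keepdims=True)
    print(multilabel_scores[..., 1])  # three independent probabilities; need not sum to 1

    # multilabel=False: entailment logits compete across labels
    entail = outputs[..., -1]
    singlelabel_scores = np.exp(entail) / np.exp(entail).sum(-1, keepdims=True)
    print(singlelabel_scores)  # sums to 1 across the three labels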

    Inherited members

  • TorchBase:
  • quantize_model
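
    A minimal usage sketch for ZeroShotClassifier, using the default facebook/bart-large-mnli weights; the example document and candidate labels are placeholders:

    from ktrain.text.zsl import ZeroShotClassifier

    zsl = ZeroShotClassifier()  # defaults to facebook/bart-large-mnli
    doc = "I am extremely happy with the battery life of this phone."
    labels = ["electronics", "politics", "sports"]
    # include_labels=True pairs each label with its probability
    print(zsl.predict(doc, labels=labels, include_labels=True))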
  •