• R/O
  • HTTP
  • SSH
  • HTTPS

Tags
Keine Tags

Frequently used words (click to add to your profile)

java, c++, android, linux, c#, windows, objective-c, cocoa, 誰得, qt, python, php, ruby, game, gui, bathyscaphe, c, 計画中(planning stage), 翻訳, omegat, framework, twitter, dom, test, vb.net, directx, ゲームエンジン, btron, arduino, previewer

Emergent generative agents


File Info

Rev. 356751b59d378fbe6ea336211e0c3588cd08eccb
Größe 2,451 Bytes
Zeit 2023-06-12 03:14:32
Autor Corbin
Log Message

Add WP query access, and fix a few bugs.

Most importantly, make choices on RWKV work. The same technique should
work on LLaMA too.

Content

import importlib.util
import os
import sys

import tokenizers

from gens.base import Gen, Yarn

# Make rwkv.cpp's Python helper scripts available to this module. They are
# loaded straight from files under the rwkv.cpp install tree by bare_import
# below (presumably because they are not installed as an importable package).
# "@RWKV@" looks like a build-time substitution placeholder for the rwkv.cpp
# install prefix -- NOTE(review): confirm which packaging step rewrites it;
# keep the literal byte-identical so that substitution still matches.
RWKV = "@RWKV@"
# Directory that bare_import reads rwkv_cpp_shared_library.py,
# rwkv_cpp_model.py and sampling.py from.
RWKV_PATH = os.path.join(RWKV, "bin")

def bare_import(path, module_name):
    """Load ``<path>/<module_name>.py`` as a top-level module named *module_name*.

    The module is registered in ``sys.modules`` *before* it executes, so
    scripts loaded this way can ``import`` one another by name. If executing
    the file fails, the partially initialised entry is removed from
    ``sys.modules`` again (as the regular import system does) and the original
    exception propagates, so a later import cannot pick up a broken module.

    :param path: directory containing the source file.
    :param module_name: file stem; also used as the module's name.
    :returns: the executed module object.
    :raises ImportError: if no loader can be found for the file.
    """
    file_path = os.path.join(path, module_name + ".py")
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"cannot load module {module_name!r} from {file_path!r}",
            name=module_name,
            path=file_path,
        )
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-executed module behind for later imports.
        sys.modules.pop(module_name, None)
        raise
    return module

# Load rwkv.cpp's helper scripts from RWKV_PATH, strictly in this order.
# bare_import registers each one in sys.modules, which presumably lets the
# later scripts import the earlier ones by name -- NOTE(review): confirm the
# order dependency against the rwkv.cpp sources before reordering.
(rwkv_cpp_shared_library,
 rwkv_cpp_model,
 sampling) = [bare_import(RWKV_PATH, name)
              for name in ("rwkv_cpp_shared_library", "rwkv_cpp_model", "sampling")]

# Tokenizer definition file shipped under the rwkv.cpp install prefix.
TOKENIZER_PATH = os.path.join(RWKV, "share", "20B_tokenizer.json")

# Sampling parameters handed to sampling.sample_logits.
# Upstream recommends temperature 0.7 and top_p 0.5; an earlier setting here
# was 0.8 / 0.8.
TEMPERATURE = 0.9
TOP_P = 0.9

class MawrkovGen(Gen):
    """A ``Gen`` that runs an RWKV model through the rwkv.cpp bindings.

    Text is tokenized with the tokenizer file at TOKENIZER_PATH. Every call to
    :meth:`fork` returns a fresh :class:`MawrkovYarn` sharing this instance's
    model and tokenizer.
    """

    model_name = "The Pile"
    model_arch = "RWKV"

    def __init__(self, model_path, max_new_tokens):
        """Load the rwkv.cpp model stored at *model_path*.

        :param model_path: path to the model file handed to rwkv.cpp.
        :param max_new_tokens: cap on how many tokens a forked yarn's
            ``complete`` may sample per call.
        """
        self.max_new_tokens = max_new_tokens
        # Memory estimate: the model file's size in bytes, scaled by 1.5.
        on_disk_bytes = os.stat(model_path).st_size
        self.model_size = on_disk_bytes * 3 // 2
        self.tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
        self.lib = rwkv_cpp_shared_library.load_rwkv_shared_library()
        self.model = rwkv_cpp_model.RWKVModel(self.lib, model_path)

    def footprint(self):
        """Return the memory estimate computed in ``__init__``."""
        # XXX: flagged as wrong by the original author; the 1.5x factor is a
        # rough guess -- NOTE(review): replace with a measured figure.
        return self.model_size

    def contextLength(self):
        """Return the context length reported for this model (fixed at 8192)."""
        return 8192

    def tokenize(self, s):
        """Encode the string *s* and return its token ids."""
        encoding = self.tokenizer.encode(s)
        return encoding.ids

    def decode(self, ts):
        """Turn the token ids *ts* back into text."""
        return self.tokenizer.decode(ts)

    def fork(self):
        """Start a new yarn (no logits/state yet) sharing this model and tokenizer."""
        return MawrkovYarn(self.max_new_tokens, self.model, self.tokenizer)

class MawrkovYarn(Yarn):
    """One generation thread over a model shared with its parent ``MawrkovGen``.

    Everything the yarn has seen is carried in ``state``, and ``logits`` holds
    the model output for the most recently evaluated token. Both start out as
    ``None`` and are filled in by :meth:`feedForward`.
    """

    logits = state = None

    def __init__(self, max_new_tokens, model, tokenizer):
        """Bind the shared *model* and *tokenizer*; no tokens are evaluated yet."""
        self.max_new_tokens, self.model, self.tokenizer = (
            max_new_tokens, model, tokenizer)

    def feedForward(self, tokens):
        """Evaluate each token in *tokens*, in order, updating logits and state."""
        for token in tokens:
            # The current state is passed as both the input and the output
            # buffer, and the previous logits as the output buffer for the new
            # logits -- presumably so rwkv.cpp can update them in place
            # (NOTE(review): confirm against RWKVModel.eval).
            self.logits, self.state = self.model.eval(
                token, self.state, self.state, self.logits)

    def complete(self):
        """Sample up to ``max_new_tokens`` tokens, feeding each back into the model.

        Sampling stops early at the first token whose decoded text contains a
        newline; that token is neither returned nor fed into the model.
        Assumes ``feedForward`` has already run so ``self.logits`` is set.

        :returns: the list of sampled token ids.
        """
        produced = []
        for _ in range(self.max_new_tokens):
            candidate = sampling.sample_logits(self.logits, TEMPERATURE, TOP_P)
            piece = self.tokenizer.decode([candidate])
            if "\n" in piece:
                break
            produced.append(candidate)
            self.feedForward([candidate])
        return produced

    def force(self, options):
        """Return the token in *options* with the highest current logit.

        Ties go to the earliest option. The chosen token is not fed into the
        model here.
        """
        score_of = self.logits.__getitem__
        return max(options, key=score_of)