Emergent generative agents
Revision: bb8a464e79a4ccd7e0ef517bbdd226ac6446c5d5 (tree)
Date: 2023-06-12 03:14:32
Author: Corbin <cds@corb...>
Committer: Corbin
Bump RWKV and LLaMA dependencies.
@@ -5,11 +5,11 @@
         "systems": "systems"
       },
       "locked": {
-        "lastModified": 1681202837,
-        "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=",
+        "lastModified": 1685518550,
+        "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=",
         "owner": "numtide",
         "repo": "flake-utils",
-        "rev": "cfacdce06f30d2b68473a46042957675eebb3401",
+        "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef",
         "type": "github"
       },
       "original": {
@@ -39,11 +39,11 @@
         "nixpkgs": "nixpkgs"
       },
       "locked": {
-        "lastModified": 1682476640,
-        "narHash": "sha256-mLVO3T86AaXg3CXJNEjxjAaUGRhB0+8J1yGRV9LN/8M=",
+        "lastModified": 1685666615,
+        "narHash": "sha256-+hL7AHJja9EyGbwKOIB4yK/6QsMkSgyi/Dv0YpWl9wA=",
         "owner": "MostAwesomeDude",
         "repo": "llama.cpp",
-        "rev": "8a2a3e5098187a70a3949aa8a9351f2f26478d84",
+        "rev": "a48b63971f6e4229c00cf23da4cc09881a8f6874",
         "type": "github"
       },
       "original": {
@@ -70,11 +70,11 @@
     },
     "nixpkgs_2": {
       "locked": {
-        "lastModified": 1682453498,
-        "narHash": "sha256-WoWiAd7KZt5Eh6n+qojcivaVpnXKqBsVgpixpV2L9CE=",
+        "lastModified": 1685564631,
+        "narHash": "sha256-8ywr3AkblY4++3lIVxmrWZFzac7+f32ZEhH/A8pNscI=",
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "c8018361fa1d1650ee8d4b96294783cf564e8a7f",
+        "rev": "4f53efe34b3a8877ac923b9350c874e3dcd5dc0a",
         "type": "github"
       },
       "original": {
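The three hunks above are flake.lock pin bumps for flake-utils, a llama.cpp fork, and nixpkgs. Entries like these are not edited by hand: they are regenerated by Nix itself, e.g. with `nix flake update` (all inputs) or `nix flake lock --update-input <name>` (a single input), which rewrites `lastModified`, `narHash`, and `rev` to match the newly fetched revision.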
@@ -13,13 +13,13 @@
       llama-lib = llama-cpp-lib.packages.${system}.default;
       llama-cpp-python = pkgs.python310.pkgs.buildPythonPackage rec {
         pname = "llama-cpp-python";
-        version = "0.1.38";
+        version = "0.1.57";

         src = pkgs.fetchFromGitHub {
           owner = "abetlen";
           repo = pname;
           rev = "v${version}";
-          sha256 = "sha256-/Ykndsp6puFxa+FSHNln9M2frS7/sMMBJSNJ/mU/CSI=";
+          sha256 = "sha256-BrR3N+3KRu96j0MIydyrvFb2BN3COeBPISac+ixq3XM=";
         };
         format = "setuptools";

@@ -28,8 +28,13 @@
           sed -i -e "s,_load_shared_library(_lib_base_name),ctypes.CDLL('${llama-lib}/lib/libllama.so')," llama_cpp/llama_cpp.py
         '';

+        # Imports server.app, which needs fancy Starlette packages I'm not
+        # willing to deal with right now. ~ C.
+        doCheck = false;
+
         propagatedBuildInputs = with pkgs.python310.pkgs; [
-          typing-extensions
+          numpy typing-extensions
+          # anyio fastapi uvicorn
         ];
       };
       sentence-transformers = pkgs.python310.pkgs.buildPythonPackage rec {
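In `buildPythonPackage`, `doCheck = false` skips the check phase entirely, so the bumped llama-cpp-python builds even though its test suite apparently imports `llama_cpp.server.app`, which would otherwise drag in the commented-out `anyio`/`fastapi`/`uvicorn` dependencies.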
@@ -57,12 +62,9 @@

         src = pkgs.fetchFromGitHub {
           owner = "saharNooby";
-          # owner = "iacore";
           repo = "rwkv.cpp";
-          rev = "c736ef5411606b529d3a74c139ee111ef1a28bb9";
-          sha256 = "sha256-zJFmuhyY2kT/WVStBpHSnlmwclXZmVoiFvsurCDHW4E=";
-          # rev = "ae390c6";
-          # sha256 = "sha256-ojDsZgXwd3+E6AGtB/KANGz3Y0W5l9CWGjfhjJEefDQ=";
+          rev = "363dfb1a061507aee661300fc8e2e153b6e99dc2";
+          sha256 = "sha256-HlJmXMXSUNgPJN6TSGnNeeBeY3/9HmRH9Qa2d4jPEu4=";
           fetchSubmodules = true;
         };

@@ -119,7 +121,7 @@
       devShells.default = pkgs.mkShell {
         name = "zirpu-env";
         packages = with pkgs; [
-          git
+          git gdb
           # our Python
           py
           # catching Python mistakes
@@ -17,7 +17,8 @@ from twisted.internet.threads import deferToThread
 from twisted.words.protocols.irc import IRCClient

 from common import irc_line, Timer, SentenceIndex, breakAt
-from gens.mawrkov import MawrkovGen, force
+from gens.camelid import CamelidGen
+from gens.mawrkov import MawrkovGen
 from gens.trans import SentenceEmbed

 build_traits = " + ".join
@@ -29,9 +30,16 @@ def load_character(path):
         return json.load(handle)

 MAX_NEW_TOKENS = 128
-print("~ Initializing mawrkov adapter…")
-model_path = sys.argv[1]
-gen = MawrkovGen(model_path, MAX_NEW_TOKENS)
+gens = {
+    "llama": CamelidGen,
+    "rwkv": MawrkovGen,
+}
+model_cls = sys.argv[1]
+if model_cls not in gens:
+    raise ValueError("must be one of %r" % tuple(gens.keys()))
+print("~ Initializing adapter:", model_cls)
+model_path = sys.argv[2]
+gen = gens[model_cls](model_path, MAX_NEW_TOKENS)
 # Need to protect per-gen data structures in C.
 genLock = Lock()
 GiB = 1024 ** 3
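The entry point now takes the backend name first: `sys.argv[1]` selects `llama` or `rwkv` from the `gens` dispatch table, `sys.argv[2]` is the model path, and (per the `go()` hunk below) character files shift to `sys.argv[3:]`. A hypothetical invocation, with script and file names invented purely for illustration: `python agent.py rwkv ./models/rwkv-4.bin characters/alice.json`.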
@@ -50,7 +58,6 @@ prologues = {

 class Mind:
     currentTag = None
-    logits = state = None

     def __init__(self, yarn, name):
         self.yarn = yarn
@@ -103,7 +110,7 @@ class Mind:
         # Breaking the newline invariant...
         d.addCallback(lambda _: deferToThread(self.writeRaw, prefix))
         # ...so that this inference happens before the newline...
-        d.addCallback(lambda _: force(tokens, self.logits))
+        d.addCallback(lambda _: self.yarn.force(tokens))
         # XXX decode should be on gen ABC
         d.addCallback(lambda t: gen.tokenizer.decode([t]))

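Forcing is now dispatched through the yarn (`self.yarn.force(tokens)`) rather than a free function taking cached logits, which is what let the `logits = state = None` class attribute be dropped from `Mind` in the previous hunk: each backend's yarn now owns its own inference state.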
@@ -279,7 +286,7 @@ def go():
     clock = Clock()
     LoopingCall(clock.go).start(60 * 30, now=True)

-    for logpath in sys.argv[2:]:
+    for logpath in sys.argv[3:]:
         character = load_character(logpath)
         title = character["title"]
         firstStatement = f"I am {title}."
@@ -8,41 +8,59 @@ class CamelidGen(Gen):
     model_name = "LLaMA?"
     model_arch = "LLaMA"
     def __init__(self, model_path, max_new_tokens):
-        self.llama = Llama(model_path, n_ctx=1024)
+        self.llama = Llama(model_path)
         self.model_size = os.stat(model_path).st_size * 3 // 2
         self.max_new_tokens = max_new_tokens

     def footprint(self): return self.model_size
     def contextLength(self): return llama_cpp.llama_n_ctx(self.llama.ctx)

-    # XXX doesn't work?
     def tokenize(self, s): return self.llama.tokenize(s.encode("utf-8"))
-    def decode(self, ts): return self.llama.detokenize(ts)
+    def decode(self, ts): return self.llama.detokenize(ts).decode("utf-8")

     def fork(self):
         return CamelidYarn(self.max_new_tokens, self.llama, self.llama.save_state())

+yarn_cache = [None]
+
 class CamelidYarn(Yarn):
     def __init__(self, max_new_tokens, llama, state):
         self.max_new_tokens = max_new_tokens
         self.llama = llama
         self.state = state

+    def activate(self):
+        if yarn_cache[0] is not self:
+            if yarn_cache[0]: yarn_cache[0].deactivate()
+            yarn_cache[0] = self
+            self.llama.load_state(self.state)
+
+    def deactivate(self): self.state = self.llama.save_state()
+
     def feedForward(self, tokens):
-        self.llama.load_state(self.state)
+        self.activate()
         self.llama.eval(tokens)
-        self.state = self.llama.save_state()

     def complete(self):
-        self.llama.load_state(self.state)
+        self.activate()
         tokens = []
         for _ in range(self.max_new_tokens):
             token = self.llama.sample()
-            if "\n" in self.llama.detokenize([token]): break
+            if b"\n" in self.llama.detokenize([token]): break
             tokens.append(token)
-        self.state = self.llama.save_state()
         return tokens

+    def force(self, options):
+        self.activate()
+        return self.llama.sample(logits_processor=Force(options))
+
+class Force:
+    def __init__(self, options): self.options = options
+    def __call__(self, input_ids, scores):
+        # +10 is hopefully not too much.
+        for option in self.options: scores[option] += 10
+        return scores
+
 class CamelidEmbed:
     def __init__(self, model_path):
         self.llama = Llama(model_path, embedding=True)
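The interesting change here is the `yarn_cache` single-slot cache: instead of a `load_state`/`save_state` round-trip around every call, a yarn only swaps llama.cpp's state in when a different yarn was active, so consecutive calls on the same yarn skip the expensive copies. A minimal, self-contained sketch of the pattern; the `Backend` class is a stand-in invented for illustration (integers playing the role of KV-cache state), not part of this codebase:

    class Backend:
        """Pretend inference engine with one mutable hidden state."""
        def __init__(self): self.state = 0
        def save_state(self): return self.state
        def load_state(self, state): self.state = state
        def eval(self, tokens): self.state += len(tokens)

    active = [None]  # at most one yarn's state lives in the backend

    class Yarn:
        def __init__(self, backend):
            self.backend = backend
            self.state = backend.save_state()

        def activate(self):
            # Swap our snapshot in only if someone else owns the backend;
            # the previous owner saves its live state first.
            if active[0] is not self:
                if active[0]: active[0].deactivate()
                active[0] = self
                self.backend.load_state(self.state)

        def deactivate(self):
            self.state = self.backend.save_state()

        def feed_forward(self, tokens):
            self.activate()  # no state copy unless we were evicted
            self.backend.eval(tokens)

    backend = Backend()
    a, b = Yarn(backend), Yarn(backend)
    a.feed_forward([1, 2, 3])  # a owns the backend; live state is 3
    a.feed_forward([4])        # same yarn again: zero state copies
    b.feed_forward([5])        # evicts a (a.state becomes 4), restores b's 0
    assert a.state == 4 and backend.state == 1

The one-element `yarn_cache = [None]` list in the diff serves the same purpose as `active` here: a module-level mutable cell is the cheapest way to share "who is currently loaded" across all yarn instances.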
@@ -25,8 +25,10 @@ sampling = bare_import(RWKV_PATH, "sampling")
 TOKENIZER_PATH = os.path.join(RWKV, "share", "20B_tokenizer.json")

 # Upstream recommends temp 0.7, top_p 0.5
-TEMPERATURE = 0.8
-TOP_P = 0.8
+# TEMPERATURE = 0.8
+# TOP_P = 0.8
+TEMPERATURE = 0.9
+TOP_P = 0.9

 class MawrkovGen(Gen):
     model_name = "The Pile"
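For context on the knobs being raised from 0.8 to 0.9: temperature divides the logits before the softmax (higher values flatten the distribution), and top-p keeps only the smallest set of highest-probability tokens whose cumulative mass reaches p. A toy illustration of both, with made-up logits:

    import math

    def nucleus(logits, temperature=0.9, top_p=0.9):
        # Softmax of temperature-scaled logits (max-subtracted for stability).
        z = max(logits.values())
        probs = {t: math.exp((l - z) / temperature) for t, l in logits.items()}
        total = sum(probs.values())
        probs = {t: p / total for t, p in probs.items()}
        # Keep the smallest high-probability prefix with cumulative mass >= top_p.
        kept, mass = {}, 0.0
        for t, p in sorted(probs.items(), key=lambda kv: -kv[1]):
            kept[t], mass = p, mass + p
            if mass >= top_p:
                break
        return kept

    print(nucleus({"a": 2.0, "b": 1.0, "c": 0.0, "d": -1.0}))
    # {'a': 0.679, 'b': 0.223} -- "c" and "d" fall outside the nucleus

Raising both settings therefore makes the agent's output measurably more diverse than upstream's recommendation, at some cost in coherence.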
@@ -68,7 +70,7 @@ class MawrkovYarn(Yarn):
             self.feedForward([token])
         return tokens

-def force(options, logits):
-    # +10 is hopefully not too much.
-    biases = {opt: 10 for opt in options}
-    return sampling.sample_logits(logits, TEMPERATURE, TOP_P, biases)
+    def force(self, options):
+        # +10 is hopefully not too much.
+        biases = {opt: 10 for opt in options}
+        return sampling.sample_logits(self.logits, TEMPERATURE, TOP_P, biases)
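Both backends now expose `force(options)` as a yarn method. Rather than hard-masking the vocabulary, the candidate tokens get +10 added to their logits, which multiplies each option's probability by e^10 ≈ 22,000 and all but guarantees the sample lands on one of them. A toy version of the idea over a dict of fake logits (the helper names are invented for illustration):

    import math, random

    def sample(logits):
        # Categorical draw from softmax(logits); stands in for the real samplers.
        z = max(logits.values())
        weights = {t: math.exp(l - z) for t, l in logits.items()}
        r = random.uniform(0, sum(weights.values()))
        for t, w in weights.items():
            r -= w
            if r <= 0:
                return t
        return t  # float-rounding guard

    def force(logits, options, bias=10.0):
        # A soft constraint: the options dominate, but other tokens keep a
        # vanishing chance, matching the "+10 is hopefully not too much"
        # comment in the diff.
        biased = dict(logits)
        for opt in options:
            biased[opt] = biased.get(opt, 0.0) + bias
        return sample(biased)

    fake = {0: 1.5, 1: 0.3, 2: -0.2, 3: 0.0}
    print(force(fake, options=[2, 3]))  # almost always prints 2 or 3

The CamelidYarn version reaches the same end through llama-cpp-python's `logits_processor` hook, while the MawrkovYarn version passes the biases straight into rwkv.cpp's `sample_logits`.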