Emergent generative agents
Revision | 356751b59d378fbe6ea336211e0c3588cd08eccb (tree) |
---|---|
Zeit | 2023-06-12 03:14:32 |
Autor | Corbin <cds@corb...> |
Committer | Corbin |
Add WP query access, and fix a few bugs.
Most importantly, make choices on RWKV work. The same technique should
work on LLaMA too.
@@ -52,9 +52,22 @@ | ||
52 | 52 | |
53 | 53 | doCheck = false; |
54 | 54 | }; |
55 | + mediawiki = pkgs.python310.pkgs.buildPythonPackage rec { | |
56 | + pname = "pymediawiki"; | |
57 | + version = "0.7.2"; | |
58 | + | |
59 | + src = pkgs.fetchPypi { | |
60 | + inherit pname version; | |
61 | + sha256 = "sha256-4KjtKSnWBRyZgHrnUCwdKwSx9FftkQ2+vJuptfnWqUI="; | |
62 | + }; | |
63 | + | |
64 | + propagatedBuildInputs = with pkgs.python310.pkgs; [ | |
65 | + beautifulsoup4 requests | |
66 | + ]; | |
67 | + }; | |
55 | 68 | py = pkgs.python310.withPackages (ps: with ps; [ |
56 | 69 | faiss llama-cpp-python sentence-transformers tokenizers transformers torch |
57 | - twisted | |
70 | + twisted mediawiki | |
58 | 71 | ]); |
59 | 72 | rwkv = pkgs.stdenv.mkDerivation { |
60 | 73 | name = "rwkv.cpp"; |
@@ -4,6 +4,7 @@ from collections import deque | ||
4 | 4 | from datetime import datetime |
5 | 5 | import json |
6 | 6 | import os.path |
7 | +import random | |
7 | 8 | import re |
8 | 9 | from string import ascii_uppercase |
9 | 10 | import sys |
@@ -16,6 +17,9 @@ from twisted.internet.task import LoopingCall, deferLater | ||
16 | 17 | from twisted.internet.threads import deferToThread |
17 | 18 | from twisted.words.protocols.irc import IRCClient |
18 | 19 | |
20 | +from mediawiki import MediaWiki | |
21 | +from mediawiki.exceptions import DisambiguationError, MediaWikiException, PageError | |
22 | + | |
19 | 23 | from common import irc_line, Timer, SentenceIndex, breakAt |
20 | 24 | from gens.camelid import CamelidGen |
21 | 25 | from gens.mawrkov import MawrkovGen |
@@ -54,6 +58,7 @@ prologues = { | ||
54 | 58 | "thoughts": "Then I thought to myself:", |
55 | 59 | "choice": "Then I chose what to do next:", |
56 | 60 | "irc": "Then I chatted on IRC, channel {subtag}:", |
61 | + "wp": "Then I searched Wikipedia by title and read a summary:", | |
57 | 62 | } |
58 | 63 | |
59 | 64 | class Mind: |
@@ -155,6 +160,7 @@ uppercase = ascii_uppercase[1:] | ||
155 | 160 | choices = { |
156 | 161 | "irc": "Chat on IRC, channel {subtag}", |
157 | 162 | "thoughts": "Think to myself for a moment", |
163 | + "wp": "Search Wikipedia", | |
158 | 164 | } |
159 | 165 | def choicify(options): |
160 | 166 | return " ".join(f"({c}) {o}" for c, o in zip(uppercase, options)) |
@@ -166,7 +172,10 @@ class ChoiceMaker(Agent): | ||
166 | 172 | self.mind = mind |
167 | 173 | self.possibilities = set() |
168 | 174 | self.actions = {} |
169 | - def overhear(self, tag, subtag, s): self.possibilities.add((tag, subtag)) | |
175 | + def overhear(self, tag, subtag, s): | |
176 | + # XXX hack | |
177 | + if tag == "irc" and subtag is None: return | |
178 | + self.possibilities.add((tag, subtag)) | |
170 | 179 | def dispatch(self, tag, subtag): |
171 | 180 | if tag in self.actions: return self.actions[tag](subtag) |
172 | 181 | def idle(self): |
@@ -176,11 +185,19 @@ class ChoiceMaker(Agent): | ||
176 | 185 | if len(possibilities) == 0: return |
177 | 186 | elif len(possibilities) == 1: return self.dispatch(*possibilities[0]) |
178 | 187 | |
188 | + random.shuffle(possibilities) | |
179 | 189 | prompt = choicify([choices[tag].format(subtag=subtag) |
180 | 190 | for tag, subtag in possibilities]) |
181 | 191 | d = self.mind.forceChoice(self.tag, self.subtag, prompt + " My choice: ", |
182 | 192 | uppercase[:len(possibilities)]) |
183 | - d.addCallback(lambda s: self.dispatch(*possibilities[indexify(s)])) | |
193 | + | |
194 | + @d.addCallback | |
195 | + def cb(s): | |
196 | + index = indexify(s) | |
197 | + try: return self.dispatch(*possibilities[index]) | |
198 | + except IndexError: | |
199 | + self.broadcast("But that was an invalid choice, so I had to choose again.") | |
200 | + return deferLater(reactor, 0.0, self.idle) | |
184 | 201 | return d |
185 | 202 | |
186 | 203 | class ChainOfThoughts(Agent): |
@@ -198,7 +215,7 @@ class Recall(Agent): | ||
198 | 215 | tag = "memories" |
199 | 216 | def __init__(self, index, seed, idle): |
200 | 217 | self.index = index |
201 | - self.recentThoughts = deque([seed], maxlen=5) | |
218 | + self.recentThoughts = deque([seed], maxlen=25) | |
202 | 219 | self.idle = idle |
203 | 220 | |
204 | 221 | def go(self): |
@@ -216,11 +233,11 @@ class Recall(Agent): | ||
216 | 233 | if thought not in self.recentThoughts: |
217 | 234 | new += 1 |
218 | 235 | self.recentThoughts.append(thought) |
219 | - self.broadcast(thought) | |
236 | + deferLater(reactor, 0.0, self.broadcast, thought) | |
220 | 237 | return new |
221 | 238 | |
222 | 239 | |
223 | -IRC_LINE_HEAD = re.compile(r"\d{1,2}:\d{1,2}:\d{1,2} <") | |
240 | +IRC_LINE_HEAD = re.compile(r"\d{1,2}:\d{1,2}:\d{1,2}( #[-a-z]+)? <") | |
224 | 241 | def breakIRCLine(line): |
225 | 242 | return IRC_LINE_HEAD.split(breakAt(line.strip(), "\n"), maxsplit=1)[0] |
226 | 243 |
@@ -244,8 +261,7 @@ class IRCAgent(Agent, IRCClient): | ||
244 | 261 | def userLeft(self, user, channel): |
245 | 262 | self.broadcastTagged(channel, f"{user} leaves {channel}") |
246 | 263 | |
247 | - def userQuit(self, user, channel): | |
248 | - self.broadcastTagged(channel, f"{user} quits {channel}") | |
264 | + def userQuit(self, user, reason): self.broadcast(f"{user} quit") | |
249 | 265 | |
250 | 266 | def topicUpdated(self, user, channel, topic): |
251 | 267 | self.broadcastTagged(channel, f"Topic for {channel} is now: {topic}") |
@@ -281,6 +297,36 @@ class IRCFactory(ClientFactory): | ||
281 | 297 | self.cm.actions["irc"] = protocol.speakInChannel |
282 | 298 | return protocol |
283 | 299 | |
300 | +class ReadWP(Agent): | |
301 | + tag = "wp" | |
302 | + def __init__(self, mind): | |
303 | + self.wp = MediaWiki() | |
304 | + self.mind = mind | |
305 | + | |
306 | + def go(self): | |
307 | + d = self.mind.infer(self.tag, self.subtag, "Title search:") | |
308 | + | |
309 | + @d.addCallback | |
310 | + def cb(title): | |
311 | + title = title.strip() | |
312 | + if not title: | |
313 | + self.broadcast("No query provided; trying again.") | |
314 | + return deferLater(reactor, 0.0, self.go) | |
315 | + try: | |
316 | + page = self.wp.page(title) | |
317 | + summary = page.summary[:200] | |
318 | + if len(page.summary) > 200: summary += "..." | |
319 | + self.broadcast("Title: " + page.title) | |
320 | + self.broadcast("Summary: " + summary) | |
321 | + except DisambiguationError as de: | |
322 | + self.broadcast("Ambiguous query. Possible titles: " + | |
323 | + ", ".join(de.options)) | |
324 | + except MediaWikiException: | |
325 | + self.broadcast("Search query was too long.") | |
326 | + except PageError: self.broadcast("No article with that title.") | |
327 | + | |
328 | + return d | |
329 | + | |
284 | 330 | def go(): |
285 | 331 | print("~ Starting tasks…") |
286 | 332 | clock = Clock() |
@@ -305,11 +351,16 @@ def go(): | ||
305 | 351 | |
306 | 352 | memories = Recall(thoughtIndex, firstStatement, cm.idle) |
307 | 353 | memories.listeners = lbi, mind |
308 | - LoopingCall(memories.go).start(60 * 2, now=True) | |
354 | + LoopingCall(memories.go).start(60 * 5, now=True) | |
309 | 355 | |
310 | 356 | thoughts = ChainOfThoughts(mind, cm.idle) |
311 | 357 | thoughts.listeners = memories, lbi, mind |
312 | - LoopingCall(thoughts.go).start(60 * 2, now=False) | |
358 | + LoopingCall(thoughts.go).start(60 * 5, now=False) | |
359 | + | |
360 | + readwp = ReadWP(mind) | |
361 | + readwp.listeners = cm, lbi, mind | |
362 | + LoopingCall(readwp.go).start(60 * 30, now=True) | |
363 | + cm.actions["wp"] = lambda subtag: readwp.go() | |
313 | 364 | |
314 | 365 | factory = IRCFactory(mind, cm, (memories, lbi, cm, mind), |
315 | 366 | title, |
@@ -70,7 +70,4 @@ class MawrkovYarn(Yarn): | ||
70 | 70 | self.feedForward([token]) |
71 | 71 | return tokens |
72 | 72 | |
73 | - def force(self, options): | |
74 | - # +10 is hopefully not too much. | |
75 | - biases = {opt: 10 for opt in options} | |
76 | - return sampling.sample_logits(self.logits, TEMPERATURE, TOP_P, biases) | |
73 | + def force(self, options): return max(options, key=self.logits.__getitem__) |