Emergent generative agents
Revision | 5b55807030216ac952c2c35395ff243cba572994 (tree) |
---|---|
Zeit | 2023-05-20 04:04:24 |
Autor | Corbin <cds@corb...> |
Commiter | Corbin |
Prototype choices, create subtagging for messages.
These go hand-in-hand; we want to be able to choose to not just reply on
IRC, but reply on a specific channel, so we need to use the channel as
subtag data.
The choice mechanism is clunky; I would very much like to just use
ReLLM, and maybe that'll be important in the future if/when we migrate
back to HF with their brand-new RWKV integration. They don't support
quantized models, though...
@@ -6,18 +6,19 @@ import json | ||
6 | 6 | import os.path |
7 | 7 | import re |
8 | 8 | import random |
9 | +from string import ascii_uppercase | |
9 | 10 | import sys |
10 | 11 | from threading import Lock |
11 | 12 | |
12 | 13 | from twisted.internet import reactor |
13 | 14 | from twisted.internet.defer import DeferredLock, succeed |
14 | 15 | from twisted.internet.protocol import ClientFactory |
15 | -from twisted.internet.task import LoopingCall | |
16 | +from twisted.internet.task import LoopingCall, deferLater | |
16 | 17 | from twisted.internet.threads import deferToThread |
17 | 18 | from twisted.words.protocols.irc import IRCClient |
18 | 19 | |
19 | 20 | from common import irc_line, Timer, SentenceIndex, breakAt |
20 | -from gens.mawrkov import MawrkovGen | |
21 | +from gens.mawrkov import MawrkovGen, force | |
21 | 22 | from gens.trans import SentenceEmbed |
22 | 23 | |
23 | 24 | build_traits = " + ".join |
@@ -43,55 +44,85 @@ prologues = { | ||
43 | 44 | "clock": "Then I checked the time:", |
44 | 45 | "lbi": "Then I tried to interpret what just happened:", |
45 | 46 | "thoughts": "Then I thought to myself:", |
46 | - "irc": "Then I chatted on IRC:", | |
47 | + "choice": "Then I chose what to do next:", | |
48 | + "irc": "Then I chatted on IRC, channel {subtag}:", | |
47 | 49 | } |
48 | 50 | |
49 | 51 | class Mind: |
50 | 52 | currentTag = None |
51 | 53 | logits = state = None |
52 | 54 | |
53 | - def __init__(self): self.lock = DeferredLock() | |
55 | + def __init__(self, name): | |
56 | + # Exclusively for debugging! | |
57 | + self.name = name | |
58 | + self.lock = DeferredLock() | |
54 | 59 | |
55 | - def switchTag(self, tag): | |
56 | - if tag == self.currentTag: return succeed(None) | |
60 | + def switchTag(self, tag, subtag): | |
61 | + if (tag, subtag) == self.currentTag: return succeed(None) | |
57 | 62 | else: |
58 | - self.currentTag = tag | |
63 | + self.currentTag = tag, subtag | |
59 | 64 | # Double newlines are added here. |
60 | - return deferToThread(self.write, "\n" + prologues[tag]) | |
65 | + return deferToThread(self.write, "\n" + | |
66 | + prologues[tag].format(subtag=subtag)) | |
61 | 67 | |
62 | - def overhear(self, tag, s): | |
68 | + def overhear(self, tag, subtag, s): | |
63 | 69 | def cb(): |
64 | - d = self.switchTag(tag) | |
70 | + d = self.switchTag(tag, subtag) | |
65 | 71 | d.addCallback(lambda _: deferToThread(self.write, s)) |
66 | 72 | return d |
67 | 73 | return self.lock.run(cb) |
68 | 74 | |
69 | - def write(self, s): | |
75 | + def writeRaw(self, s): | |
70 | 76 | with genLock: |
71 | - # Newlines are added here. | |
72 | - self.logits, self.state = gen.feedForward(gen.tokenize(s + "\n"), | |
77 | + self.logits, self.state = gen.feedForward(gen.tokenize(s), | |
73 | 78 | self.logits, self.state) |
74 | - print(s) | |
79 | + print(self.name, "~~", s.strip()) | |
80 | + | |
81 | + # Newlines are added here. | |
82 | + def write(self, s): self.writeRaw(s + "\n") | |
75 | 83 | |
76 | 84 | def complete(self, s): |
77 | 85 | with genLock: |
78 | 86 | completion, self.logits, self.state = gen.complete(s, self.logits, self.state) |
79 | - print("«", completion, "»") | |
87 | + print(self.name, "«", completion.strip(), "»") | |
80 | 88 | return completion |
81 | 89 | |
82 | - def infer(self, tag, prefix): | |
90 | + def infer(self, tag, subtag, prefix): | |
91 | + def cb(): | |
92 | + d = self.switchTag(tag, subtag) | |
93 | + d.addCallback(lambda _: print(self.name, "~?", prefix) | |
94 | + or deferToThread(self.complete, prefix)) | |
95 | + return d | |
96 | + return self.lock.run(cb) | |
97 | + | |
98 | + def forceChoice(self, tag, subtag, prefix, options): | |
99 | + tokens = [gen.tokenize(opt)[0] for opt in options] | |
83 | 100 | def cb(): |
84 | - print(prefix) | |
85 | - d = self.switchTag(tag) | |
86 | - d.addCallback(lambda _: deferToThread(self.complete, prefix)) | |
101 | + d = self.switchTag(tag, subtag) | |
102 | + # Breaking the newline invariant... | |
103 | + d.addCallback(lambda _: print(self.name, "~!", prefix) | |
104 | + or deferToThread(self.writeRaw, prefix)) | |
105 | + # ...so that this inference happens before the newline... | |
106 | + d.addCallback(lambda _: force(tokens, self.logits)) | |
107 | + d.addCallback(lambda t: gen.tokenizer.decode([t])) | |
108 | + | |
109 | + @d.addCallback | |
110 | + def cb2(s): | |
111 | + # ...and now the invariant must be restored... | |
112 | + d2 = deferToThread(self.write, s) | |
113 | + # ...but we want to tell the caller what was chosen! | |
114 | + d2.addCallback(lambda _: s) | |
115 | + return d2 | |
87 | 116 | return d |
88 | 117 | return self.lock.run(cb) |
89 | 118 | |
90 | 119 | |
91 | 120 | class Agent: |
92 | 121 | listeners = () |
122 | + subtag = None | |
93 | 123 | def broadcast(self, s): |
94 | - for listener in self.listeners: listener.overhear(self.tag, s) | |
124 | + for listener in self.listeners: | |
125 | + listener.overhear(self.tag, self.subtag, s) | |
95 | 126 | |
96 | 127 | class Clock(Agent): |
97 | 128 | tag = "clock" |
@@ -102,45 +133,77 @@ class LeftBrainInterpreter(Agent): | ||
102 | 133 | tag = "lbi" |
103 | 134 | events = 0 |
104 | 135 | def __init__(self, mind): self.mind = mind |
105 | - def overhear(self, tag, s): | |
136 | + def overhear(self, tag, subtag, s): | |
137 | + # XXX should skip thoughts, maybe? | |
106 | 138 | self.events += 1 |
107 | - if self.events >= 10: | |
139 | + if self.events >= 20: | |
108 | 140 | self.events = 0 |
109 | - return self.mind.infer(self.tag, "") | |
141 | + return self.mind.infer(self.tag, self.subtag, "") | |
142 | + | |
143 | +choices = { | |
144 | + "irc": "Chat on IRC, channel {subtag}", | |
145 | + "thoughts": "Think to myself for a moment", | |
146 | +} | |
147 | +def choicify(options): | |
148 | + return " ".join(f"({c}) {o}" for c, o in zip(ascii_uppercase, options)) | |
149 | + | |
150 | +class ChoiceMaker(Agent): | |
151 | + tag = "choice" | |
152 | + def __init__(self, mind): | |
153 | + self.mind = mind | |
154 | + self.possibilities = set() | |
155 | + self.actions = {} | |
156 | + def overhear(self, tag, subtag, s): self.possibilities.add((tag, subtag)) | |
157 | + def dispatch(self, tag, subtag): | |
158 | + if tag in self.actions: return self.actions[tag](subtag) | |
159 | + def idle(self): | |
160 | + possibilities = [(tag, subtag) | |
161 | + for tag, subtag in self.possibilities | |
162 | + if tag in choices] | |
163 | + if len(possibilities) == 0: return | |
164 | + elif len(possibilities) == 1: return self.dispatch(*possibilities[0]) | |
165 | + | |
166 | + options = [choices[tag].format(subtag=subtag) | |
167 | + for tag, subtag in possibilities] | |
168 | + prompt = choicify(options.values()) | |
169 | + self.broadcast(prompt) | |
170 | + d = self.mind.forceChoice(self.tag, self.subtag, "My choice: ", | |
171 | + ascii_uppercase[:len(options)]) | |
172 | + d.addCallback(lambda s: self.dispatch(*possibilities[chr(ord(s[0]) - | |
173 | + ord('A'))])) | |
174 | + return d | |
110 | 175 | |
111 | 176 | class ChainOfThoughts(Agent): |
112 | 177 | tag = "thoughts" |
113 | - def __init__(self, mind, index, seed): | |
178 | + def __init__(self, mind, index, seed, idle): | |
114 | 179 | self.mind = mind |
115 | 180 | self.index = index |
116 | 181 | self.recentThoughts = deque([seed], maxlen=5) |
182 | + self.idle = idle | |
117 | 183 | |
118 | - def go(self): | |
119 | - cb = self.reflect if random.choice([0, 1]) else self.cogitate | |
120 | - return cb() | |
184 | + def go(self): return random.choice([self.reflect, self.cogitate])() | |
121 | 185 | |
122 | - def overhear(self, tag, s): | |
123 | - self.addRelatedThoughts(tag + ": " + s) | |
186 | + def overhear(self, tag, subtag, s): | |
187 | + prefix = f"{tag} ({subtag}): " if subtag else tag + ": " | |
188 | + self.addRelatedThoughts(prefix + s) | |
124 | 189 | |
125 | 190 | def addRelatedThoughts(self, s): |
126 | 191 | thoughts = self.index.search(s, 2) |
192 | + new = 0 | |
127 | 193 | for thought in thoughts: |
128 | 194 | if thought not in self.recentThoughts: |
195 | + new += 1 | |
129 | 196 | self.recentThoughts.append(thought) |
130 | 197 | self.broadcast(thought) |
198 | + return new | |
131 | 199 | |
132 | - def cogitate(self): self.addRelatedThoughts(self.recentThoughts[-1]) | |
200 | + def cogitate(self): | |
201 | + new = self.addRelatedThoughts(self.recentThoughts[-1]) | |
202 | + if not new: return deferLater(reactor, 0.0, self.idle) | |
133 | 203 | def reflect(self): |
134 | - d = self.mind.infer(self.tag, "") | |
135 | - | |
136 | - @d.addCallback | |
137 | - def cb(s): | |
138 | - if not s.strip(): | |
139 | - self.broadcast(random.choice([ | |
140 | - "Head empty; no thoughts.", | |
141 | - "So bored.", | |
142 | - "Zoned out.", | |
143 | - ])) | |
204 | + d = self.mind.infer(self.tag, self.subtag, "") | |
205 | + d.addCallback(lambda s: s.strip() or self.idle()) | |
206 | + return d | |
144 | 207 | |
145 | 208 | |
146 | 209 | IRC_LINE_HEAD = re.compile(r"\d{1,2}:\d{1,2}:\d{1,2} <") |
@@ -155,40 +218,45 @@ class IRCAgent(Agent, IRCClient): | ||
155 | 218 | self.nickname = title_to_nick(title) |
156 | 219 | self.startingChannels = startingChannels |
157 | 220 | |
221 | + def broadcastTagged(self, subtag, s): | |
222 | + for listener in self.listeners: listener.overhear(self.tag, subtag, s) | |
223 | + | |
158 | 224 | def prefix(self, channel): |
159 | 225 | return f"{datetime.now():%H:%M:%S} {channel} <{self.nickname}>" |
160 | 226 | |
161 | 227 | def userJoined(self, user, channel): |
162 | - self.broadcast(f"{user} joins {channel}") | |
228 | + self.broadcastTagged(channel, f"{user} joins {channel}") | |
163 | 229 | |
164 | 230 | def userLeft(self, user, channel): |
165 | - self.broadcast(f"{user} leaves {channel}") | |
231 | + self.broadcastTagged(channel, f"{user} leaves {channel}") | |
166 | 232 | |
167 | 233 | def userQuit(self, user, channel): |
168 | - self.broadcast(f"{user} quits {channel}") | |
234 | + self.broadcastTagged(channel, f"{user} quits {channel}") | |
169 | 235 | |
170 | 236 | def topicUpdated(self, user, channel, topic): |
171 | - self.broadcast(f"Topic for {channel} is now: {topic}") | |
237 | + self.broadcastTagged(channel, f"Topic for {channel} is now: {topic}") | |
172 | 238 | |
173 | 239 | def signedOn(self): |
174 | 240 | for channel in self.startingChannels: self.join(channel) |
175 | 241 | |
176 | 242 | def privmsg(self, user, channel, line): |
177 | 243 | user = user.split("!", 1)[0] |
178 | - self.broadcast(irc_line(datetime.now(), channel, user, line)) | |
179 | - if self.nickname in line: | |
180 | - d = self.mind.infer("irc", self.prefix(channel)) | |
244 | + self.broadcastTagged(channel, irc_line(datetime.now(), channel, user, line)) | |
245 | + # XXX should eventually be not needed! | |
246 | + # if self.nickname in line: self.speakInChannel(channel) | |
181 | 247 | |
182 | - @d.addCallback | |
183 | - def cb(s): | |
184 | - line = breakIRCLine(s).strip() | |
185 | - self.msg(channel, line) | |
248 | + def speakInChannel(self, channel): | |
249 | + d = self.mind.infer("irc", channel, self.prefix(channel)) | |
250 | + | |
251 | + @d.addCallback | |
252 | + def cb(s): self.msg(channel, breakIRCLine(s).strip()) | |
186 | 253 | |
187 | 254 | class IRCFactory(ClientFactory): |
188 | 255 | protocol = IRCAgent |
189 | - def __init__(self, mind, listeners, title, startingChannels): | |
256 | + def __init__(self, mind, cm, listeners, title, startingChannels): | |
190 | 257 | super(IRCFactory, self).__init__() |
191 | 258 | self.mind = mind |
259 | + self.cm = cm | |
192 | 260 | self.listeners = listeners |
193 | 261 | self.title = title |
194 | 262 | self.startingChannels = startingChannels |
@@ -196,12 +264,13 @@ class IRCFactory(ClientFactory): | ||
196 | 264 | protocol = self.protocol(self.mind, self.title, self.startingChannels) |
197 | 265 | protocol.factory = self |
198 | 266 | protocol.listeners = self.listeners |
267 | + self.cm.actions["irc"] = protocol.speakInChannel | |
199 | 268 | return protocol |
200 | 269 | |
201 | 270 | def go(): |
202 | 271 | print("~ Starting tasks…") |
203 | 272 | clock = Clock() |
204 | - LoopingCall(clock.go).start(300.0, now=False) | |
273 | + LoopingCall(clock.go).start(60 * 10, now=False) | |
205 | 274 | |
206 | 275 | for logpath in sys.argv[2:]: |
207 | 276 | character = load_character(logpath) |
@@ -210,19 +279,21 @@ def go(): | ||
210 | 279 | thoughtPath = os.path.join(logpath, "thoughts.txt") |
211 | 280 | thoughtIndex = SentenceIndex.fromPath(thoughtPath, embedder) |
212 | 281 | |
213 | - mind = Mind() | |
282 | + mind = Mind(title_to_nick(title)) | |
214 | 283 | with Timer("initial warmup"): |
215 | 284 | mind.logits, mind.state = gen.feedForward(gen.tokenize(firstStatement), None, None) |
216 | 285 | |
217 | 286 | lbi = LeftBrainInterpreter(mind) |
218 | 287 | clock.listeners += lbi, mind |
219 | 288 | |
220 | - thoughts = ChainOfThoughts(mind, thoughtIndex, firstStatement) | |
289 | + cm = ChoiceMaker(mind) | |
290 | + | |
291 | + thoughts = ChainOfThoughts(mind, thoughtIndex, firstStatement, cm.idle) | |
221 | 292 | thoughts.listeners = lbi, mind |
222 | - LoopingCall(thoughts.go).start(120.0, now=False) | |
293 | + LoopingCall(thoughts.go).start(60, now=False) | |
223 | 294 | |
224 | 295 | print("~ Thought index:", thoughtIndex.size(), "thoughts") |
225 | - factory = IRCFactory(mind, (thoughts, lbi, mind), | |
296 | + factory = IRCFactory(mind, cm, (thoughts, lbi, cm, mind), | |
226 | 297 | title, |
227 | 298 | character["startingChannels"]) |
228 | 299 | print("~ Connecting factory for:", title) |
@@ -27,8 +27,7 @@ TEMPERATURE = 0.8 | ||
27 | 27 | TOP_P = 0.8 |
28 | 28 | |
29 | 29 | class MawrkovGen: |
30 | - # XXX might be wrong | |
31 | - model_name = "The Pile (14B params, 4-bit quantized)" | |
30 | + model_name = "The Pile" | |
32 | 31 | model_arch = "RWKV" |
33 | 32 | def __init__(self, model_path, max_new_tokens): |
34 | 33 | self.max_new_tokens = max_new_tokens |
@@ -57,3 +56,8 @@ class MawrkovGen: | ||
57 | 56 | logits, state = self.feedForward([token], logits, state) |
58 | 57 | if "\n" in self.tokenizer.decode([token]): break |
59 | 58 | return self.tokenizer.decode(tokens).split("\n", 1)[0], logits, state |
59 | + | |
60 | +def force(options, logits): | |
61 | + # +10 is hopefully not too much. | |
62 | + biases = {opt: 10 for opt in options} | |
63 | + return sampling.sample_logits(logits, TEMPERATURE, TOP_P, biases) |